diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 561b5b27b..a270f75b8 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -298,7 +298,11 @@ source = table | where ispresent(a) | - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` - +- `source = table1 | inner join on table1.a = table2.a table2 | fields table1.a, table2.a, table1.b, table1.c` (refer to table names directly) +- `source = table1 | inner join on a = c table2 | fields a, b, c, d` (side aliases can be omitted as long as there is no ambiguity) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields l.a, r.a` (side alias overrides table alias) +- `source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a` (error, side alias overrides table alias) +- `source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a` (error, side alias overrides subquery alias) #### **Lookup** [See additional command details](ppl-lookup-command.md) diff --git a/docs/ppl-lang/ppl-join-command.md b/docs/ppl-lang/ppl-join-command.md index 525373f7c..b374bce5f 100644 --- a/docs/ppl-lang/ppl-join-command.md +++ b/docs/ppl-lang/ppl-join-command.md @@ -65,8 +65,8 @@ WHERE t1.serviceName = `order` SEARCH source= | | [joinType] JOIN - leftAlias - rightAlias + [leftAlias] + [rightAlias] [joinHints] ON joinCriteria @@ -79,12 +79,12 @@ SEARCH source= **leftAlias** - Syntax: `left = ` -- Required +- Optional - Description: The subquery alias to use with the left join side, to avoid ambiguous naming. **rightAlias** - Syntax: `right = ` -- Required +- Optional - Description: The subquery alias to use with the right join side, to avoid ambiguous naming. 
**joinHints** @@ -138,11 +138,11 @@ Rewritten by PPL Join query: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o - ON c.c_custkey = o.o_custkey AND o_comment NOT LIKE '%unusual%packages%' +| LEFT OUTER JOIN + ON c_custkey = o_custkey AND o_comment NOT LIKE '%unusual%packages%' orders -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` _- **Limitation: sub-searches is unsupported in join right side**_ @@ -151,14 +151,15 @@ If sub-searches is supported, above ppl query could be rewritten as: ```sql SEARCH source=customer | FIELDS c_custkey -| LEFT OUTER JOIN left = c, right = o ON c.c_custkey = o.o_custkey +| LEFT OUTER JOIN + ON c_custkey = o_custkey [ SEARCH source=orders | WHERE o_comment NOT LIKE '%unusual%packages%' | FIELDS o_orderkey, o_custkey ] -| STATS count(o_orderkey) AS c_count BY c.c_custkey -| STATS count(1) AS custdist BY c_count +| STATS count(o_orderkey) AS c_count BY c_custkey +| STATS count() AS custdist BY c_count | SORT - custdist, - c_count ``` diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index 00e55d50a..3127325c8 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, 
EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -924,4 +924,271 @@ class FlintSparkPPLJoinITSuite s }.size == 13) } + + test("test multiple joins without table aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN ON $testTable1.name = $testTable2.name $testTable2 + | | JOIN ON $testTable2.name = $testTable3.name $testTable3 + | | fields $testTable1.name, $testTable2.name, $testTable3.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some( + EqualTo( + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name"), + UnresolvedAttribute(s"$testTable3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with part subquery aliases") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | fields t1.name, t2.name, t3.name + | 
""".stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello"), + Row("John", "John", "John"), + Row("David", "David", "David"), + Row("David", "David", "David"), + Row("Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name")), + joinPlan2) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 1") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + 
val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with self join 2") { + val frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + assertSameRows( + Array( + Row("Jake", "Jake", "Jake", "Jake"), + Row("Hello", "Hello", "Hello", "Hello"), + Row("John", "John", "John", "John"), + Row("David", "David", "David", "David"), + Row("David", "David", "David", "David"), + Row("Jane", "Jane", "Jane", "Jane")), + frame) + + val logicalPlan = frame.queryExecution.logical + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), 
UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("check access the reference by aliases") { + var frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 as t1 + | | JOIN ON t1.name = t2.name $testTable2 as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 ] as t2 + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + + frame = sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as t2 ] + | | fields t1.name, t2.name + | """.stripMargin) + assert(frame.collect().length > 0) + } + + test("access the reference by override aliases should throw exception") { + var ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | 
fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as tt ] + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 ON t1.name = t2.name [ source = $testTable2 as tt ] as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 ] as tt + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + + ex = intercept[AnalysisException](sql(s""" + | source = $testTable1 as tt + | | JOIN left = t1 ON t1.name = t2.name $testTable2 as t2 + | | fields tt.name + | """.stripMargin)) + assert(ex.getMessage.contains("`tt`.`name` cannot be resolved")) + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 8bb93567b..b0dcc4dda 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -326,7 +326,7 @@ joinType ; sideAlias - : LEFT EQUAL leftAlias = ident COMMA? RIGHT EQUAL rightAlias = ident + : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? 
; joinCriteria diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java index b513d01bf..dd9947329 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/DescribeRelation.java @@ -8,12 +8,14 @@ import lombok.ToString; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import java.util.Collections; + /** * Extend Relation to describe the table itself */ @ToString public class DescribeRelation extends Relation{ public DescribeRelation(UnresolvedExpression tableName) { - super(tableName); + super(Collections.singletonList(tableName)); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java index 89f787d34..176902911 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -25,15 +25,15 @@ public class Join extends UnresolvedPlan { private UnresolvedPlan left; private final UnresolvedPlan right; - private final String leftAlias; - private final String rightAlias; + private final Optional leftAlias; + private final Optional rightAlias; private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; @Override public UnresolvedPlan attach(UnresolvedPlan child) { - this.left = new SubqueryAlias(leftAlias, child); + this.left = leftAlias.isEmpty() ? 
child : new SubqueryAlias(leftAlias.get(), child); return this; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 1b30a7998..483b61b02 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -10,19 +10,17 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; /** Logical plan node of Relation, the interface for building the searching sources. */ -@AllArgsConstructor @ToString +@Getter @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Relation extends UnresolvedPlan { @@ -30,27 +28,6 @@ public class Relation extends UnresolvedPlan { private final List tableName; - public Relation(UnresolvedExpression tableName) { - this(tableName, null); - } - - public Relation(UnresolvedExpression tableName, String alias) { - this.tableName = Arrays.asList(tableName); - this.alias = alias; - } - - /** Optional alias name for the relation. */ - @Setter @Getter private String alias; - - /** - * Return table name. 
- * - * @return table name - */ - public List getTableName() { - return tableName.stream().map(Object::toString).collect(Collectors.toList()); - } - public List getQualifiedNames() { return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java index 29c3d4b90..ba66cca80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/SubqueryAlias.java @@ -6,19 +6,14 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; -import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import java.util.List; -import java.util.Objects; -@AllArgsConstructor @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor @ToString public class SubqueryAlias extends UnresolvedPlan { @Getter private final String alias; @@ -32,6 +27,11 @@ public SubqueryAlias(UnresolvedPlan child, String suffix) { this.child = child; } + public SubqueryAlias(String alias, UnresolvedPlan child) { + this.alias = alias; + this.child = child; + } + public List getChild() { return ImmutableList.of(child); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index 571905f8a..3658591de 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -94,6 +94,11 @@ public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext c 
return unresolved.accept(this, context); } + /** This method is only for analyze the join condition expression */ + public Expression analyzeJoinCondition(UnresolvedExpression unresolved, CatalystPlanContext context) { + return context.resolveJoinCondition(unresolved, this::analyze); + } + @Override public Expression visitLiteral(Literal node, CatalystPlanContext context) { return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( @@ -181,6 +186,11 @@ public Expression visitCompare(Compare node, CatalystPlanContext context) { @Override public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + // When the qualified name is part of join condition, for example: table1.id = table2.id + // findRelation(context.traversalContext() only returns relation table1 which cause table2.id fail to resolve + if (context.isResolvingJoinCondition()) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } List relation = findRelation(context.traversalContext()); if (!relation.isEmpty()) { Optional resolveField = resolveField(relation, node, context.getRelations()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 61762f616..53dc17576 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl; +import lombok.Getter; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -39,19 +40,19 @@ public class CatalystPlanContext { /** * Catalyst relations list **/ - private List projectedFields = new 
ArrayList<>(); + @Getter private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ - private List relations = new ArrayList<>(); + @Getter private List relations = new ArrayList<>(); /** * Catalyst SubqueryAlias list **/ - private List subqueryAlias = new ArrayList<>(); + @Getter private List subqueryAlias = new ArrayList<>(); /** * Catalyst evolving logical plan **/ - private Stack planBranches = new Stack<>(); + @Getter private Stack planBranches = new Stack<>(); /** * The current traversal context the visitor is going threw */ @@ -60,28 +61,12 @@ public class CatalystPlanContext { /** * NamedExpression contextual parameters **/ - private final Stack namedParseExpressions = new Stack<>(); + @Getter private final Stack namedParseExpressions = new Stack<>(); /** * Grouping NamedExpression contextual parameters **/ - private final Stack groupingParseExpressions = new Stack<>(); - - public Stack getPlanBranches() { - return planBranches; - } - - public List getRelations() { - return relations; - } - - public List getSubqueryAlias() { - return subqueryAlias; - } - - public List getProjectedFields() { - return projectedFields; - } + @Getter private final Stack groupingParseExpressions = new Stack<>(); public LogicalPlan getPlan() { if (this.planBranches.isEmpty()) return null; @@ -101,10 +86,6 @@ public Stack traversalContext() { return planTraversalContext; } - public Stack getNamedParseExpressions() { - return namedParseExpressions; - } - public void setNamedParseExpressions(Stack namedParseExpressions) { this.namedParseExpressions.clear(); this.namedParseExpressions.addAll(namedParseExpressions); @@ -114,10 +95,6 @@ public Optional popNamedParseExpressions() { return namedParseExpressions.isEmpty() ? 
Optional.empty() : Optional.of(namedParseExpressions.pop()); } - public Stack getGroupingParseExpressions() { - return groupingParseExpressions; - } - /** * define new field * @@ -154,13 +131,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + public LogicalPlan applyBranches(List> plans) { plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); planBranches.remove(0); return getPlan(); - } - + } + /** * append plan with evolving plans branches * @@ -288,4 +265,21 @@ public static Optional findRelation(LogicalPlan plan) { return Optional.empty(); } + @Getter private boolean isResolvingJoinCondition = false; + + /** + * Resolve the join condition with the given function. + * A flag will be set to true ahead expression resolving, then false after resolving. + * @param expr + * @param transformFunction + * @return + */ + public Expression resolveJoinCondition( + UnresolvedExpression expr, + BiFunction transformFunction) { + isResolvingJoinCondition = true; + Expression result = transformFunction.apply(expr, this); + isResolvingJoinCondition = false; + return result; + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 3ad1b95cb..093e017ce 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -278,7 +278,8 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); - Optional joinCondition = node.getJoinCondition().map(c -> visitExpression(c, context)); + Optional joinCondition = node.getJoinCondition() + .map(c -> 
expressionAnalyzer.analyzeJoinCondition(c, context)); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 09db8b126..acec3c2a2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -155,14 +155,26 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - // left and right aliases are required in join syntax. 
Setting by 'AS' causes ambiguous - throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); } + if (ctx.sideAlias().rightAlias != null) { + rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); - UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); + UnresolvedPlan right; + if (rightAlias.isPresent() + && rightRelation instanceof SubqueryAlias + && rightAlias.get().equals(((SubqueryAlias) rightRelation).getAlias())) { + right = rightRelation; + } else if (rightAlias.isPresent()) { + right = new SubqueryAlias(rightAlias.get(), rightRelation); + } else { + right = rightRelation; + } Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -370,7 +382,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo /** Lookup command */ @Override public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Relation lookupRelation = new Relation(Collections.singletonList(this.internalVisitExpression(ctx.tableSource()))); Lookup.OutputStrategy strategy = ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); @@ -485,9 +497,8 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return ctx.alias == null - ? 
new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) - : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); + Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation; } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..f4ed397e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -271,9 +271,9 @@ class PPLLogicalPlanJoinTranslatorTestSuite pplParser, s""" | source = $testTable1 - | | inner JOIN left = l,right = r ON l.id = r.id $testTable2 - | | left JOIN left = l,right = r ON l.name = r.name $testTable3 - | | cross JOIN left = l,right = r $testTable4 + | | inner JOIN left = l right = r ON l.id = r.id $testTable2 + | | left JOIN left = l right = r ON l.name = r.name $testTable3 + | | cross JOIN left = l right = r $testTable4 | """.stripMargin) val logicalPlan = planTransformer.visit(logPlan, context) val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) @@ -443,17 +443,17 @@ class PPLLogicalPlanJoinTranslatorTestSuite s""" | source = $testTable1 | | head 10 - | | inner JOIN left = l,right = r ON l.id = r.id + | | inner JOIN left = l right = r ON l.id = r.id | [ | source = $testTable2 | | where id > 10 | ] - | | left JOIN left = l,right = r ON l.name = r.name + | | left JOIN left = l right = r ON l.name = r.name | [ | source = $testTable3 | | fields id 
| ] - | | cross JOIN left = l,right = r + | | cross JOIN left = l right = r | [ | source = $testTable4 | | sort id @@ -565,4 +565,284 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test multiple joins with table alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with table and subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | 
[ + | source = table4 as t4 + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("l", SubqueryAlias("t1", table1)), + SubqueryAlias("r", SubqueryAlias("t2", table2)), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + SubqueryAlias("l", joinPlan1), + SubqueryAlias("r", SubqueryAlias("t3", table3)), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( + SubqueryAlias("l", joinPlan2), + SubqueryAlias("r", SubqueryAlias("t4", table4)), + Inner, + Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))), + JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins without table aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN ON table1.id = table2.id table2 + | | JOIN ON table1.id = table3.id table3 + | | JOIN ON table2.id = table4.id table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + table1, + table2, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table2.id"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + table3, + Inner, + Some(EqualTo(UnresolvedAttribute("table1.id"), UnresolvedAttribute("table3.id"))), + JoinHint.NONE) + val joinPlan3 = Join( 
+ joinPlan2, + table4, + Inner, + Some(EqualTo(UnresolvedAttribute("table2.id"), UnresolvedAttribute("table4.id"))), + JoinHint.NONE) + /* Expected result: a star projection over the nested joins; table4 is joined as a bare relation, with no SubqueryAlias around it. */ val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + /* Mixed aliasing: only the first JOIN declares a left alias; the second and third declare just a right alias. The expected plan wraps each newly joined table in its right alias and uses the accumulated join plan as the left side without any SubqueryAlias. */ test("test multiple joins with part subquery aliases") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = table1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name table2 + | | JOIN right = t3 ON t1.name = t3.name table3 + | | JOIN right = t4 ON t2.name = t4.name table4 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + val table3 = UnresolvedRelation(Seq("table3")) + val table4 = UnresolvedRelation(Seq("table4")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + /* left side: the previous join, not re-aliased */ joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + SubqueryAlias("t4", table4), + Inner, + Some(EqualTo(UnresolvedAttribute("t2.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + /* No fields command in the query, so the expected projection is an unresolved star. */ val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + /* Self join through a right alias: the last JOIN targets the same table as the source and aliases it t4; the fields command selects name through all four aliases. */ test("test multiple joins with self join 1") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN right = t4 ON t1.name = t4.name $testTable1 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = 
planTransformer.visit(logPlan, context) + /* Only three relations are declared: table1 backs both the source side (t1) and the self-joined side (t4). The three-part names presumably match what the testTable1, testTable2 and testTable3 suite fields expand to - TODO confirm against their definitions. */ val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + /* same relation as t1, under the alias t4 */ SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + /* The fields command maps to an explicit projection of the four qualified name attributes. */ val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + /* Self join through a subsearch: the fourth side is a bracketed subsearch over the source table with a trailing as t4, and that JOIN declares neither a left nor a right alias. */ test("test multiple joins with self join 2") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2 + | | JOIN right = t3 ON t1.name = t3.name $testTable3 + | | JOIN ON t1.name = t4.name + | [ + | source = $testTable1 + | ] as t4 + | | fields t1.name, t2.name, t3.name, t4.name + | """.stripMargin) + + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + SubqueryAlias("t2", table2), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), 
UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val joinPlan2 = Join( + joinPlan1, + SubqueryAlias("t3", table3), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t3.name"))), + JoinHint.NONE) + val joinPlan3 = Join( + joinPlan2, + /* the subsearch alias t4 is expected as a SubqueryAlias directly over table1 */ SubqueryAlias("t4", table1), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t4.name"))), + JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("t1.name"), + UnresolvedAttribute("t2.name"), + UnresolvedAttribute("t3.name"), + UnresolvedAttribute("t4.name")), + joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + /* Alias stacking on the right side: the subsearch carries a table alias (ttt), a subquery alias (tt) and the JOIN side alias (t2). The join condition and the fields command reference only t1 and t2. */ test("test side alias will override the subquery alias") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val joinPlan1 = Join( + SubqueryAlias("t1", table1), + /* side alias t2 is outermost; tt and ttt stay nested beneath it in the expected plan */ SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))), + Inner, + Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))), + JoinHint.NONE) + val expectedPlan = + Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } }