diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index c9c9636be..bd60f8224 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,27 +6,18 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; -import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.RowFrame$; -import org.apache.spark.sql.catalyst.expressions.RowNumber; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; -import org.apache.spark.sql.catalyst.expressions.UnspecifiedFrame; -import org.apache.spark.sql.catalyst.expressions.UnspecifiedFrame$; -import org.apache.spark.sql.catalyst.expressions.WindowExpression; -import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; @@ -41,11 +32,9 @@ import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -import org.jetbrains.annotations.NotNull; import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; -import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Compare; @@ -89,7 +78,6 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.ppl.utils.DataTypeTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; @@ -101,7 +89,6 @@ import scala.collection.Seq; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -276,11 +263,9 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { @Override public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) { - final String APPENDCOL_ID = WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME; final String TABLE_LHS = "T1"; final String TABLE_RHS = "T2"; - final List projectList = getRowNumStarProjection(context); scala.collection.mutable.Seq fieldsToRemove = seq( UnresolvedAttribute$.MODULE$.apply(TABLE_LHS + "." + APPENDCOL_ID), UnresolvedAttribute$.MODULE$.apply(TABLE_RHS + "." + APPENDCOL_ID)); @@ -288,60 +273,37 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) { new Field(QualifiedName.of(TABLE_LHS ,APPENDCOL_ID)), new Field(QualifiedName.of(TABLE_RHS, APPENDCOL_ID))); - - // Add a new projection layer with * and ROW_NUMBER (Main-search) - // Inject an addition search command into sub-search - // Add a new projection layer with * and ROW_NUMBER (Sub-search) - - // Add a new projection layer with * and ROW_NUMBER (Main-search) LogicalPlan leftTemp = node.getChild().get(0).accept(this, context); + var mainSearch = getRowNumStarProjection(context, leftTemp, TABLE_LHS); + context.withSubqueryAlias(mainSearch); - - // Add the row_number - LogicalPlan t1WithRowNumber = new org.apache.spark.sql.catalyst.plans.logical.Project(seq( - projectList), leftTemp); - - // To wrap it into T1 - var t1Table = SubqueryAlias$.MODULE$.apply(TABLE_LHS, t1WithRowNumber); - context.withSubqueryAlias(t1Table); - + // Inject an addition search command into sub-search (T2) + appendRelationClause(node.getSubSearch(), "employees"); context.apply(left -> { - // Inject an addition search command into sub-search (T2) - addSearchCmd(node.getSubSearch(), "employees"); - + // Add a new projection layer with * and ROW_NUMBER (Sub-search) LogicalPlan right = node.getSubSearch().accept(this, context); - - // Add the row_number - LogicalPlan t2WithRowNumber = new org.apache.spark.sql.catalyst.plans.logical.Project(seq( - projectList), right); - - // To wrap it into T2 - var t2Alias = SubqueryAlias$.MODULE$.apply(TABLE_RHS, t2WithRowNumber); - context.withSubqueryAlias(t2Alias); + var subSearch = getRowNumStarProjection(context, right, TABLE_RHS); + context.withSubqueryAlias(subSearch); Optional joinCondition = Optional.of(innerJoinCondition) .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context)); - context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); - LogicalPlan joinedQuery = join(t1Table, t2Alias, Join.JoinType.LEFT, joinCondition, new Join.JoinHint()); + + LogicalPlan joinedQuery = join(mainSearch, subSearch, Join.JoinType.LEFT, joinCondition, new Join.JoinHint()); // Remove the APPEND_ID return new org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns(fieldsToRemove, joinedQuery); - }); - - - System.out.println(context); - +// System.out.println(context); return context.getPlan(); } - private static void addSearchCmd(Node subSearch, String relationName) { + private static void appendRelationClause(Node subSearch, String relationName) { // Till traverse till the end then append. Relation table = new Relation(of(new QualifiedName(relationName))); @@ -361,7 +323,7 @@ private static void addSearchCmd(Node subSearch, String relationName) { } } - private @NotNull List getRowNumStarProjection(CatalystPlanContext context) { + private org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias getRowNumStarProjection(CatalystPlanContext context, LogicalPlan lp, String alias) { final String DUMMY_SORT_FIELD = "1"; @@ -375,7 +337,11 @@ private static void addSearchCmd(Node subSearch, String relationName) { List projectList = (context.getNamedParseExpressions().isEmpty()) ? List.of(appendCol, UnresolvedStar$.MODULE$.apply(Option.empty())) : List.of(appendCol); - return projectList; + + LogicalPlan lpWithProjection = new org.apache.spark.sql.catalyst.plans.logical.Project(seq( + projectList), lp); + return SubqueryAlias$.MODULE$.apply(alias, lpWithProjection); + } @Override