diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index c579234fc..575537826 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -8,7 +8,6 @@ import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar; @@ -17,12 +16,10 @@ import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.EqualTo; -import org.apache.spark.sql.catalyst.expressions.Equality$; import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.Predicate$; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; @@ -48,15 +45,12 @@ import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.statement.Explain; @@ -104,7 +98,6 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; -import java.util.stream.Stream; import static java.util.Collections.emptyList; import static java.util.List.of; @@ -284,37 +277,32 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) { final String TABLE_RHS = "T2"; final UnresolvedAttribute t1Attr = new UnresolvedAttribute(seq(TABLE_LHS, APPENDCOL_ID)); final UnresolvedAttribute t2Attr = new UnresolvedAttribute(seq(TABLE_RHS, APPENDCOL_ID)); - final Seq fieldsToRemove = seq(t1Attr, t2Attr); +// final Seq fieldsToRemove = seq(t1Attr, t2Attr); + final List fieldsToRemove = new ArrayList<>(List.of(t1Attr, t2Attr)); final Node mainSearchNode = node.getChild().get(0); - - // Add a new projection layer with * and ROW_NUMBER (Main-search) - LogicalPlan leftTemp = mainSearchNode.accept(this, context); - var mainSearchWithRowNumber = getRowNumStarProjection(context, leftTemp, TABLE_LHS); - context.withSubqueryAlias(mainSearchWithRowNumber); + final Node subSearchNode = node.getSubSearch(); // Traverse to look for relation clause then append it into the sub-search. Relation relation = retrieveRelationClause(mainSearchNode); appendRelationClause(node.getSubSearch(), relation); + // Add apply a dropColumns if override present, then add * with ROW_NUMBER + LogicalPlan leftTemp = mainSearchNode.accept(this, context); +// LogicalPlan mainSearch = (node.override) +// ? new DataFrameDropColumns(getoverridedlist(subSearch), leftTemp) +// : leftTemp; + var mainSearchWithRowNumber = getRowNumStarProjection(context, leftTemp, TABLE_LHS); + context.withSubqueryAlias(mainSearchWithRowNumber); + context.apply(left -> { // Add a new projection layer with * and ROW_NUMBER (Sub-search) - LogicalPlan subSearchNode = node.getSubSearch().accept(this, context); - var subSearchWithRowNumber = getRowNumStarProjection(context, subSearchNode, TABLE_RHS); + LogicalPlan subSearch = subSearchNode.accept(this, context); + var subSearchWithRowNumber = getRowNumStarProjection(context, subSearch, TABLE_RHS); context.withSubqueryAlias(subSearchWithRowNumber); context.retainAllNamedParseExpressions(p -> p); context.retainAllPlans(p -> p); - if (node.override) { - SparkSession sparkSession = SparkSession.getActiveSession().get(); - - QueryExecution queryExecution = sparkSession.sessionState().executePlan(mainSearchWithRowNumber, CommandExecutionMode.ALL()); - QueryExecution queryExecutionSub = sparkSession.sessionState().executePlan(subSearchWithRowNumber, CommandExecutionMode.ALL()); - - Seq outputMain = queryExecution.analyzed().output(); - Seq outputSub = queryExecutionSub.analyzed().output(); - } - // Composite the join clause LogicalPlan joinedQuery = join( mainSearchWithRowNumber, subSearchWithRowNumber, @@ -322,12 +310,15 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) { Optional.of(new EqualTo(t1Attr, t2Attr)), new Join.JoinHint()); + // Remove the APPEND_ID - return new DataFrameDropColumns(fieldsToRemove, joinedQuery); + if (node.override) { + List getoverridedlist = getoverridedlist(subSearchWithRowNumber, TABLE_LHS); + fieldsToRemove.addAll(getoverridedlist); + } + return new DataFrameDropColumns(seq(fieldsToRemove), joinedQuery); }); - System.out.println("Attributes: "); - System.out.println(context.getPlan().output()); return context.getPlan(); } @@ -368,6 +359,20 @@ private static Relation retrieveRelationClause(Node node) { return null; } + private static List getoverridedlist(LogicalPlan lp, String tableName) { + // When override option present, extract fields to project from sub-search, + // then apply a dfDropColumns on main-search to avoid duplicate fields. + SparkSession sparkSession = SparkSession.getActiveSession().get(); + QueryExecution queryExecutionSub = sparkSession.sessionState() + .executePlan(lp, CommandExecutionMode.ALL()); + Seq output = queryExecutionSub.analyzed().output(); + List attributes = seqAsJavaList(output); + return attributes.stream() + .map(attr -> + new UnresolvedAttribute(seq(tableName, attr.name()))) + .collect(Collectors.toList()); + } + private org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias getRowNumStarProjection(CatalystPlanContext context, LogicalPlan lp, String alias) { SortOrder sortOrder = SortUtils.sortOrder(