Skip to content

Commit

Permalink
Override impl
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Dec 14, 2024
1 parent d790d20 commit 816f6d6
Showing 1 changed file with 33 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar;
Expand All @@ -17,12 +16,10 @@
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Equality$;
import org.apache.spark.sql.catalyst.expressions.Explode;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate$;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
Expand All @@ -48,15 +45,12 @@
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.statement.Explain;
Expand Down Expand Up @@ -104,7 +98,6 @@
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.emptyList;
import static java.util.List.of;
Expand Down Expand Up @@ -284,50 +277,48 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) {
final String TABLE_RHS = "T2";
final UnresolvedAttribute t1Attr = new UnresolvedAttribute(seq(TABLE_LHS, APPENDCOL_ID));
final UnresolvedAttribute t2Attr = new UnresolvedAttribute(seq(TABLE_RHS, APPENDCOL_ID));
final Seq<Expression> fieldsToRemove = seq(t1Attr, t2Attr);
// final Seq<Expression> fieldsToRemove = seq(t1Attr, t2Attr);
final List<Expression> fieldsToRemove = new ArrayList<>(List.of(t1Attr, t2Attr));
final Node mainSearchNode = node.getChild().get(0);

// Add a new projection layer with * and ROW_NUMBER (Main-search)
LogicalPlan leftTemp = mainSearchNode.accept(this, context);
var mainSearchWithRowNumber = getRowNumStarProjection(context, leftTemp, TABLE_LHS);
context.withSubqueryAlias(mainSearchWithRowNumber);
final Node subSearchNode = node.getSubSearch();

// Traverse to look for relation clause then append it into the sub-search.
Relation relation = retrieveRelationClause(mainSearchNode);
appendRelationClause(node.getSubSearch(), relation);

// Apply a dropColumns if override is present, then add * with ROW_NUMBER
LogicalPlan leftTemp = mainSearchNode.accept(this, context);
// LogicalPlan mainSearch = (node.override)
// ? new DataFrameDropColumns(getoverridedlist(subSearch), leftTemp)
// : leftTemp;
var mainSearchWithRowNumber = getRowNumStarProjection(context, leftTemp, TABLE_LHS);
context.withSubqueryAlias(mainSearchWithRowNumber);

context.apply(left -> {
// Add a new projection layer with * and ROW_NUMBER (Sub-search)
LogicalPlan subSearchNode = node.getSubSearch().accept(this, context);
var subSearchWithRowNumber = getRowNumStarProjection(context, subSearchNode, TABLE_RHS);
LogicalPlan subSearch = subSearchNode.accept(this, context);
var subSearchWithRowNumber = getRowNumStarProjection(context, subSearch, TABLE_RHS);
context.withSubqueryAlias(subSearchWithRowNumber);

context.retainAllNamedParseExpressions(p -> p);
context.retainAllPlans(p -> p);

if (node.override) {
SparkSession sparkSession = SparkSession.getActiveSession().get();

QueryExecution queryExecution = sparkSession.sessionState().executePlan(mainSearchWithRowNumber, CommandExecutionMode.ALL());
QueryExecution queryExecutionSub = sparkSession.sessionState().executePlan(subSearchWithRowNumber, CommandExecutionMode.ALL());

Seq<Attribute> outputMain = queryExecution.analyzed().output();
Seq<Attribute> outputSub = queryExecutionSub.analyzed().output();
}

// Compose the join clause
LogicalPlan joinedQuery = join(
mainSearchWithRowNumber, subSearchWithRowNumber,
Join.JoinType.LEFT,
Optional.of(new EqualTo(t1Attr, t2Attr)),
new Join.JoinHint());


// Remove the APPEND_ID
return new DataFrameDropColumns(fieldsToRemove, joinedQuery);
if (node.override) {
List<Expression> getoverridedlist = getoverridedlist(subSearchWithRowNumber, TABLE_LHS);
fieldsToRemove.addAll(getoverridedlist);
}
return new DataFrameDropColumns(seq(fieldsToRemove), joinedQuery);
});

System.out.println("Attributes: ");
System.out.println(context.getPlan().output());
return context.getPlan();
}

Expand Down Expand Up @@ -368,6 +359,20 @@ private static Relation retrieveRelationClause(Node node) {
return null;
}

/**
 * Builds the list of columns, qualified by {@code tableName}, that correspond to the output
 * attributes of the given plan.
 *
 * <p>The plan is analyzed through the active Spark session to obtain its output attributes;
 * every attribute name is then re-qualified with {@code tableName} and wrapped in an
 * {@link UnresolvedAttribute}. When the override option is present the caller passes the
 * sub-search plan together with the main-search alias and drops the returned columns, so the
 * joined result does not contain duplicate fields.
 *
 * @param lp        logical plan whose analyzed output attributes are collected
 * @param tableName qualifier (table alias) placed in front of every attribute name
 * @return one unresolved attribute {@code tableName.<name>} per output attribute of {@code lp}
 * @throws java.util.NoSuchElementException if there is no active Spark session
 */
private static List<Expression> getoverridedlist(LogicalPlan lp, String tableName) {
    // Check explicitly instead of calling Option.get() blindly: the exception type stays the
    // same as before, but the message now tells the reader what is actually missing.
    var activeSession = SparkSession.getActiveSession();
    if (activeSession.isEmpty()) {
        throw new java.util.NoSuchElementException(
            "No active SparkSession available to resolve output attributes of the plan");
    }
    QueryExecution queryExecution = activeSession.get().sessionState()
        .executePlan(lp, CommandExecutionMode.ALL());
    List<Attribute> attributes = seqAsJavaList(queryExecution.analyzed().output());
    return attributes.stream()
        .map(attr -> new UnresolvedAttribute(seq(tableName, attr.name())))
        .collect(Collectors.toList());
}

private org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias getRowNumStarProjection(CatalystPlanContext context, LogicalPlan lp, String alias) {

SortOrder sortOrder = SortUtils.sortOrder(
Expand Down

0 comments on commit 816f6d6

Please sign in to comment.