Skip to content

Commit

Permalink
add visitFirstChild(node, context) method for the PlanVisitor to simplify node inner child access visibility
Browse files Browse the repository at this point in the history

Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Nov 9, 2024
1 parent 724cbe9 commit 424fad4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Stack;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -188,6 +189,13 @@ public LogicalPlan reduce(BiFunction<LogicalPlan, LogicalPlan, LogicalPlan> tran
}).orElse(getPlan()));
}

/**
* Update this context in place by applying the given action to it.
* Thin functional-style hook so callers can chain context mutations
* (e.g. {@code context.update(ctx -> { ...; return ctx; })}).
*
* @param action unary operator applied to this context; expected to return
*               the (possibly mutated) context — typically {@code this}
* @return whatever the action returns (the updated context)
*/
public CatalystPlanContext update(UnaryOperator<CatalystPlanContext> action) {
return action.apply(this);
}

/**
* apply for each plan with the given function
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@
import org.apache.spark.sql.catalyst.expressions.Explode;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
Expand All @@ -38,6 +31,7 @@
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.flint.spark.FlattenGenerator;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -73,7 +67,6 @@
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
Expand All @@ -91,6 +84,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -133,6 +127,10 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) {
return node.getPlan().accept(this, context);
}

/**
* Visit the node's first (and in PPL pipelines typically only) inner child,
* translating it into a {@link LogicalPlan} pushed onto the context stack.
* Centralizes the former {@code node.getChild().get(0).accept(this, context)}
* idiom repeated across the visitor.
*
* NOTE(review): assumes {@code node.getChild()} is non-empty — an empty child
* list would throw {@link IndexOutOfBoundsException}; confirm callers only
* pass nodes with at least one child.
*
* @param node    the AST node whose first child is visited
* @param context the Catalyst plan context accumulating the translated plan
* @return the logical plan produced by visiting the first child
*/
public LogicalPlan visitFirstChild(Node node, CatalystPlanContext context) {
return node.getChild().get(0).accept(this, context);
}

@Override
public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
node.getStatement().accept(this, context);
Expand Down Expand Up @@ -160,7 +158,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {

@Override
public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
return context.apply(p -> {
Expression conditionExpression = visitExpression(node.getCondition(), context);
Optional<Expression> innerConditionExpression = context.popNamedParseExpressions();
Expand All @@ -174,8 +172,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
*/
@Override
public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);

visitFirstChild(node, context);
return context.apply( searchSide -> {
LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context);
Expand Down Expand Up @@ -231,8 +228,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {

@Override
public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);

visitFirstChild(node, context);
node.getSortByField()
.ifPresent(sortField -> {
Expression sortFieldExpression = visitExpression(sortField, context);
Expand All @@ -255,7 +251,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {

@Override
public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
context.reduce((left, right) -> {
visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);
Expand All @@ -273,7 +269,7 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex

@Override
public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
return context.apply(left -> {
LogicalPlan right = node.getRight().accept(this, context);
Optional<Expression> joinCondition = node.getJoinCondition()
Expand All @@ -286,7 +282,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {

@Override
public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
return context.apply(p -> {
var alias = org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p);
context.withSubqueryAlias(alias);
Expand All @@ -297,7 +293,7 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co

@Override
public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
if (!groupExpList.isEmpty()) {
Expand Down Expand Up @@ -343,7 +339,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) {

@Override
public LogicalPlan visitWindow(Window node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
List<Expression> windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context);
Seq<Expression> windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p);
List<Expression> partitionExpList = visitExpressionList(node.getPartExprList(), context);
Expand Down Expand Up @@ -373,19 +369,22 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {

@Override
public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
if (node.isExcluded()) {
List<UnresolvedExpression> intersect = context.getProjectedFields().stream()
.filter(node.getProjectList()::contains)
.collect(Collectors.toList());
if (!intersect.isEmpty()) {
// Fields in parent projection, but they have be excluded in child. For example,
// source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
throw new SyntaxCheckException(intersect + " can't be resolved");
context.update((ctx) -> {
if (node.isExcluded()) {
List<UnresolvedExpression> intersect = ctx.getProjectedFields().stream()
.filter(node.getProjectList()::contains)
.collect(Collectors.toList());
if (!intersect.isEmpty()) {
// Fields in parent projection, but they have been excluded in child. For example,
// source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
throw new SyntaxCheckException(intersect + " can't be resolved");
}
} else {
ctx.withProjectedFields(node.getProjectList());
}
} else {
context.withProjectedFields(node.getProjectList());
}
LogicalPlan child = node.getChild().get(0).accept(this, context);
return ctx;
});
LogicalPlan child = visitFirstChild(node, context);
visitExpressionList(node.getProjectList(), context);

// Create a projection list from the existing expressions
Expand All @@ -406,28 +405,28 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {

@Override
public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
    // Diff artifact removed: the span contained both the pre-change
    // "node.getChild().get(0).accept(this, context)" and its replacement,
    // which would have visited the child twice. Only the replacement remains.
    visitFirstChild(node, context);
    // Resolve the sort fields into named parse expressions on the context stack.
    visitFieldList(node.getSortList(), context);
    // Drain those expressions, wrapping each in a SortOrder with the
    // direction declared on the Sort AST node.
    Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(
            exp -> SortUtils.getSortDirection(node, (NamedExpression) exp));
    // global=true: a total ordering over the child plan.
    return context.apply(p ->
            (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
}

@Override
public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
    // Diff artifact removed: both the old direct child-accept call and the
    // new visitFirstChild(...) call were present; keeping both would visit
    // the child plan twice. Only the replacement remains.
    visitFirstChild(node, context);
    // Wrap the child plan in a Limit with the head size as an integer literal.
    return context.apply(p -> (LogicalPlan) Limit.apply(
            new org.apache.spark.sql.catalyst.expressions.Literal(
                    node.getSize(), DataTypes.IntegerType), p));
}

@Override
public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) {
    // Diff artifact removed: the pre-change
    // "fieldSummary.getChild().get(0).accept(this, context)" duplicated the
    // new visitFirstChild(...) call; only the replacement remains.
    visitFirstChild(fieldSummary, context);
    // Delegate the actual summary translation to the dedicated transformer.
    return FieldSummaryTransformer.translate(fieldSummary, context);
}

@Override
public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) {
fillNull.getChild().get(0).accept(this, context);
visitFirstChild(fillNull, context);
List<UnresolvedExpression> aliases = new ArrayList<>();
for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) {
Field field = nullableFieldFill.getNullableFieldReference();
Expand Down Expand Up @@ -458,7 +457,7 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context)

@Override
public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
flatten.getChild().get(0).accept(this, context);
visitFirstChild(flatten, context);
if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
Expand All @@ -472,7 +471,7 @@ public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {

@Override
public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
Expand Down Expand Up @@ -508,7 +507,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan

@Override
public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
Expression sourceField = visitExpression(node.getSourceField(), context);
ParseMethod parseMethod = node.getParseMethod();
java.util.Map<String, Literal> arguments = node.getArguments();
Expand All @@ -518,7 +517,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {

@Override
public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
Expand All @@ -535,7 +534,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {

@Override
public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
List<UnresolvedExpression> aliases = new ArrayList<>();
List<Let> letExpressions = node.getExpressionList();
for (Let let : letExpressions) {
Expand All @@ -549,8 +548,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
List<Expression> expressionList = visitExpressionList(aliases, context);
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
// build the plan with the projection step
child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
return child;
return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
}

@Override
Expand All @@ -575,7 +573,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext

@Override
public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
visitFirstChild(node, context);
List<Argument> options = node.getOptions();
Integer allowedDuplication = (Integer) options.get(0).getValue().getValue();
Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue();
Expand Down

0 comments on commit 424fad4

Please sign in to comment.