From 424fad417b27e529b64159a4c93cfe1067814c3e Mon Sep 17 00:00:00 2001
From: YANGDB
Date: Sat, 9 Nov 2024 15:17:11 -0800
Subject: [PATCH] add visitFirstChild(node, context) method to the PlanVisitor
 to simplify access to a node's inner child

Signed-off-by: YANGDB
---
 .../sql/ppl/CatalystPlanContext.java          |  8 ++
 .../sql/ppl/CatalystQueryPlanVisitor.java     | 82 +++++++++----------
 2 files changed, 48 insertions(+), 42 deletions(-)

diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
index 53dc17576..5a717dd48 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
@@ -26,6 +26,7 @@
 import java.util.Stack;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
@@ -188,6 +189,13 @@ public LogicalPlan reduce(BiFunction<LogicalPlan, LogicalPlan, LogicalPlan> tran
         }).orElse(getPlan()));
     }
 
+    /**
+     * update context using the given action
+     */
+    public CatalystPlanContext update(UnaryOperator<CatalystPlanContext> action) {
+        return action.apply(this);
+    }
+
     /**
      * apply for each plan with the given function
      *
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index db46d88b0..35264b582 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -14,13 +14,6 @@
 import org.apache.spark.sql.catalyst.expressions.Explode;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
-import org.apache.spark.sql.catalyst.expressions.In$;
-import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
-import org.apache.spark.sql.catalyst.expressions.InSubquery$;
-import org.apache.spark.sql.catalyst.expressions.LessThan;
-import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
-import org.apache.spark.sql.catalyst.expressions.ListQuery$;
-import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.SortDirection;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
@@ -38,6 +31,7 @@
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.opensearch.flint.spark.FlattenGenerator;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.Node;
 import org.opensearch.sql.ast.expression.Alias;
 import org.opensearch.sql.ast.expression.Argument;
 import org.opensearch.sql.ast.expression.Field;
@@ -73,7 +67,6 @@
 import org.opensearch.sql.ast.tree.Rename;
 import org.opensearch.sql.ast.tree.Sort;
 import org.opensearch.sql.ast.tree.SubqueryAlias;
-import org.opensearch.sql.ast.tree.TopAggregation;
 import org.opensearch.sql.ast.tree.Trendline;
 import org.opensearch.sql.ast.tree.Window;
 import org.opensearch.sql.common.antlr.SyntaxCheckException;
@@ -91,6 +84,7 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
@@ -133,6 +127,10 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) {
         return node.getPlan().accept(this, context);
     }
 
+    public LogicalPlan visitFirstChild(Node node, CatalystPlanContext context) {
+        return node.getChild().get(0).accept(this, context);
+    }
+
     @Override
     public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
         node.getStatement().accept(this, context);
@@ -160,7 +158,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(p -> {
             Expression conditionExpression = visitExpression(node.getCondition(), context);
             Optional<Expression> innerConditionExpression = context.popNamedParseExpressions();
@@ -174,8 +172,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
      */
     @Override
     public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
-
+        visitFirstChild(node, context);
         return context.apply( searchSide -> {
             LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
             Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context);
@@ -231,8 +228,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
-
+        visitFirstChild(node, context);
         node.getSortByField()
                 .ifPresent(sortField -> {
                     Expression sortFieldExpression = visitExpression(sortField, context);
@@ -255,7 +251,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         context.reduce((left, right) -> {
             visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
             Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);
@@ -273,7 +269,7 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContex
 
     @Override
     public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(left -> {
             LogicalPlan right = node.getRight().accept(this, context);
             Optional<Expression> joinCondition = node.getJoinCondition()
@@ -286,7 +282,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(p -> {
             var alias = org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p);
             context.withSubqueryAlias(alias);
@@ -297,7 +293,7 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co
 
     @Override
     public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
         List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
         if (!groupExpList.isEmpty()) {
@@ -343,7 +339,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitWindow(Window node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         List<Expression> windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context);
         Seq<Expression> windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p);
         List<Expression> partitionExpList = visitExpressionList(node.getPartExprList(), context);
@@ -373,19 +369,22 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
-        if (node.isExcluded()) {
-            List<UnresolvedExpression> intersect = context.getProjectedFields().stream()
-                    .filter(node.getProjectList()::contains)
-                    .collect(Collectors.toList());
-            if (!intersect.isEmpty()) {
-                // Fields in parent projection, but they have be excluded in child. For example,
-                // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
-                throw new SyntaxCheckException(intersect + " can't be resolved");
+        context.update((ctx) -> {
+            if (node.isExcluded()) {
+                List<UnresolvedExpression> intersect = ctx.getProjectedFields().stream()
+                        .filter(node.getProjectList()::contains)
+                        .collect(Collectors.toList());
+                if (!intersect.isEmpty()) {
+                    // Fields in parent projection, but they have been excluded in child. For example,
+                    // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
+                    throw new SyntaxCheckException(intersect + " can't be resolved");
+                }
+            } else {
+                ctx.withProjectedFields(node.getProjectList());
             }
-        } else {
-            context.withProjectedFields(node.getProjectList());
-        }
-        LogicalPlan child = node.getChild().get(0).accept(this, context);
+            return ctx;
+        });
+        LogicalPlan child = visitFirstChild(node, context);
         visitExpressionList(node.getProjectList(), context);
 
         // Create a projection list from the existing expressions
@@ -406,7 +405,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         visitFieldList(node.getSortList(), context);
         Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp));
         return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
@@ -414,20 +413,20 @@ public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
                 node.getSize(), DataTypes.IntegerType), p));
     }
 
     @Override
     public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) {
-        fieldSummary.getChild().get(0).accept(this, context);
+        visitFirstChild(fieldSummary, context);
         return FieldSummaryTransformer.translate(fieldSummary, context);
     }
 
     @Override
     public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) {
-        fillNull.getChild().get(0).accept(this, context);
+        visitFirstChild(fillNull, context);
         List<UnresolvedExpression> aliases = new ArrayList<>();
         for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) {
             Field field = nullableFieldFill.getNullableFieldReference();
@@ -458,7 +457,7 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context)
 
     @Override
     public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
-        flatten.getChild().get(0).accept(this, context);
+        visitFirstChild(flatten, context);
         if (context.getNamedParseExpressions().isEmpty()) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<Expression>>empty()));
@@ -472,7 +471,7 @@ public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         if (context.getNamedParseExpressions().isEmpty()) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<Expression>>empty()));
@@ -508,7 +507,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan
 
     @Override
     public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
-        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         Expression sourceField = visitExpression(node.getSourceField(), context);
         ParseMethod parseMethod = node.getParseMethod();
         java.util.Map<String, Literal> arguments = node.getArguments();
@@ -518,7 +517,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         if (context.getNamedParseExpressions().isEmpty()) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
@@ -535,7 +534,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
-        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         List<UnresolvedExpression> aliases = new ArrayList<>();
         List<Let> letExpressions = node.getExpressionList();
         for (Let let : letExpressions) {
@@ -549,8 +548,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
         List<Expression> expressionList = visitExpressionList(aliases, context);
         Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
         // build the plan with the projection step
-        child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
-        return child;
+        return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
     }
 
     @Override
@@ -575,7 +573,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext
 
     @Override
     public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
        List<Argument> options = node.getOptions();
        Integer allowedDuplication = (Integer) options.get(0).getValue().getValue();
        Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue();
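
Note for reviewers: a condensed sketch of the two methods this patch introduces. The LogicalPlan, Node and Context types below are simplified stand-ins for the real Spark and PPL classes (all names in this sketch are illustrative, not the project's actual code):

    import java.util.List;
    import java.util.function.UnaryOperator;

    // Stand-ins for CatalystPlanContext, Node and LogicalPlan, reduced to the
    // members the two new methods touch.
    final class VisitorSketch {

        interface LogicalPlan {}

        interface Node {
            List<Node> getChild();
            LogicalPlan accept(Visitor visitor, Context context);
        }

        static final class Context {
            // Mirrors CatalystPlanContext.update(UnaryOperator<CatalystPlanContext>):
            // the context hands itself to a caller-supplied action, so a visitor
            // can group all of its context changes in one lambda.
            Context update(UnaryOperator<Context> action) {
                return action.apply(this);
            }
        }

        interface Visitor {
            // Mirrors CatalystQueryPlanVisitor.visitFirstChild(Node, CatalystPlanContext):
            // the "descend into the single input" step that every visitX method used
            // to spell out as node.getChild().get(0).accept(this, context). Note the
            // implicit precondition that the node has at least one child.
            default LogicalPlan visitFirstChild(Node node, Context context) {
                return node.getChild().get(0).accept(this, context);
            }
        }
    }

The helper also preserves the subtle split visible in the diff: visitProject and a few others keep the returned child plan, while most visitors call it only for the side effect of pushing the child onto the context's plan stack.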
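
The one block the patch relocates rather than just renames is visitProject's excluded-fields validation, now wrapped in context.update. A runnable toy version of that intersection test (class and method names here are mine) reproduces the failure mode its comment describes:

    import java.util.List;
    import java.util.stream.Collectors;

    class ExclusionCheck {
        // Fields projected by the parent `fields` command that a child
        // `fields -` command already excluded can never be resolved.
        static List<String> clash(List<String> projected, List<String> excluded) {
            return projected.stream()
                    .filter(excluded::contains)
                    .collect(Collectors.toList());
        }

        public static void main(String[] args) {
            // source=t | fields - A, B | fields A, B, C
            // prints [A, B]: the fields the child excluded, so the parent
            // projection fails with "[Field A, Field B] can't be resolved"
            System.out.println(clash(List.of("A", "B", "C"), List.of("A", "B")));
        }
    }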