diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java
index a8ec28d0e..0854c1724 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java
@@ -10,6 +10,7 @@
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 
 public class Field extends UnresolvedExpression {
   private final QualifiedName field;
@@ -47,4 +48,25 @@ public List getChild() {
   public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
     return nodeVisitor.visitField(this, context);
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    Field field1 = (Field) o;
+    return Objects.equals(field, field1.field) && Objects.equals(fieldArgs, field1.fieldArgs);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, fieldArgs);
+  }
+
+  @Override
+  public String toString() {
+    return "Field("
+        + "field=" + field
+        + ", fieldArgs=" + fieldArgs
+        + ')';
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java
index 8abd3a98c..133edf3ff 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java
@@ -12,6 +12,7 @@
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.StreamSupport;
@@ -107,4 +108,17 @@ public List getChild() {
   public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
     return nodeVisitor.visitQualifiedName(this, context);
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    QualifiedName that = (QualifiedName) o;
+    return Objects.equals(parts, that.parts);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(parts);
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 1e31e9c6f..be00796a8 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -65,6 +65,7 @@
 import org.opensearch.sql.ast.tree.Relation;
 import org.opensearch.sql.ast.tree.Sort;
 import org.opensearch.sql.ast.tree.TopAggregation;
+import org.opensearch.sql.common.antlr.SyntaxCheckException;
 import org.opensearch.sql.ppl.utils.AggregatorTranslator;
 import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
 import org.opensearch.sql.ppl.utils.ComparatorTransformer;
@@ -238,7 +239,16 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
-        if (!node.isExcluded()) {
+        if (node.isExcluded()) {
+            List<UnresolvedExpression> intersect = context.getProjectedFields().stream()
+                    .filter(node.getProjectList()::contains)
+                    .collect(Collectors.toList());
+            if (!intersect.isEmpty()) {
+                // Fields in parent projection, but they have been excluded in child. For example,
+                // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
+                throw new SyntaxCheckException(intersect + " can't be resolved");
+            }
+        } else {
             context.withProjectedFields(node.getProjectList());
         }
         LogicalPlan child = node.getChild().get(0).accept(this, context);
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index a5deac0f0..34de86d92 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -6,6 +6,7 @@
 package org.opensearch.flint.spark.ppl
 
 import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.common.antlr.SyntaxCheckException
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
 import org.scalatest.matchers.should.Matchers
 
@@ -314,4 +315,33 @@
     val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
     comparePlans(expectedPlan, logPlan, false)
   }
+
+  test("test fields + then - field list") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, "source=t | fields + A, B, C | fields - A, B", false),
+      context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val projectABC = Project(
+      Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"), UnresolvedAttribute("C")),
+      table)
+    val dropList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val dropAB = DataFrameDropColumns(dropList, projectABC)
+
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), dropAB)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test fields - then + field list") {
+    val context = new CatalystPlanContext
+    val thrown = intercept[SyntaxCheckException] {
+      planTransformer.visit(
+        plan(pplParser, "source=t | fields - A, B | fields + A, B, C", false),
+        context)
+    }
+    assert(
+      thrown.getMessage
+        === "[Field(field=A, fieldArgs=[]), Field(field=B, fieldArgs=[])] can't be resolved")
+  }
 }
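For context on why `Field` and `QualifiedName` gain `equals`/`hashCode`/`toString`: the new branch in `visitProject` intersects the parent's projected fields with an excluded field list via `contains` (which relies on `equals`), and embeds the offending fields in the `SyntaxCheckException` message (which relies on `toString`). Below is a minimal standalone sketch of that mechanism, not part of the patch; `FieldStub` and `ExclusionCheckSketch` are hypothetical names used only for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical cut-down stand-in for org.opensearch.sql.ast.expression.Field,
// reduced to the overrides the exclusion check depends on.
class FieldStub {
  private final String name;

  FieldStub(String name) {
    this.name = name;
  }

  @Override
  public boolean equals(Object o) {
    return o instanceof FieldStub && ((FieldStub) o).name.equals(name);
  }

  @Override
  public int hashCode() {
    return name.hashCode();
  }

  @Override
  public String toString() {
    return "Field(field=" + name + ", fieldArgs=[])";
  }
}

public class ExclusionCheckSketch {
  public static void main(String[] args) {
    // Parent projection: fields A, B, C (what the visitor records as projected fields)
    List<FieldStub> projected =
        Arrays.asList(new FieldStub("A"), new FieldStub("B"), new FieldStub("C"));
    // Excluded child projection: fields - A, B
    List<FieldStub> excluded = Arrays.asList(new FieldStub("A"), new FieldStub("B"));

    // Same shape as the check in visitProject: contains() uses equals(),
    // the message below uses toString().
    List<FieldStub> intersect = projected.stream()
        .filter(excluded::contains)
        .collect(Collectors.toList());

    if (!intersect.isEmpty()) {
      // Prints: [Field(field=A, fieldArgs=[]), Field(field=B, fieldArgs=[])] can't be resolved
      System.out.println(intersect + " can't be resolved");
    }
  }
}
```

The printed message matches the string asserted in the new `"test fields - then + field list"` test, which is why the default identity-based `equals`/`toString` would not have been sufficient.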