diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
index 90ffe0385..6d9c3a5ab 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala
@@ -7,8 +7,8 @@ package org.opensearch.flint.spark.ppl
 
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedTableOrView}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.DescribeTableCommand
 import org.apache.spark.sql.streaming.StreamTest
@@ -295,4 +295,133 @@ class FlintSparkPPLBasicITSuite
       assert(compareByString(expectedPlan) === compareByString(logicalPlan))
     }
   }
+
+  // `fields +` keeps only the listed columns: the expected plan is a Project over the relation.
+  test("fields plus command") {
+    Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach {
+      case (selectFields, sortField) =>
+        val frame = sql(s"""
+             | source = $testTable| fields + $selectFields | head 1 | sort $sortField
+             | """.stripMargin)
+        val results: Array[Row] = frame.collect()
+        assert(results.length == 1)
+
+        val logicalPlan: LogicalPlan = frame.queryExecution.logical
+        // Define the expected logical plan
+        val project = Project(
+          Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
+          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
+        val limitPlan: LogicalPlan = Limit(Literal(1), project)
+        val sortedPlan: LogicalPlan =
+          Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)
+
+        val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan)
+        // Compare the two plans
+        comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+    }
+  }
+
+  // `fields -` removes the listed columns: the expected plan is a DataFrameDropColumns node.
+  test("fields minus command") {
+    Seq(("state, country", "age"), ("`state`, `country`", "`age`")).foreach {
+      case (selectFields, sortField) =>
+        val frame = sql(s"""
+             | source = $testTable| fields - $selectFields | sort - $sortField | head 1
+             | """.stripMargin)
+
+        val results: Array[Row] = frame.collect()
+        assert(results.length == 1)
+        val expectedResults: Array[Row] = Array(Row("Jake", 70, 2023, 4))
+        assert(results.sameElements(expectedResults))
+
+        val logicalPlan: LogicalPlan = frame.queryExecution.logical
+        val drop = DataFrameDropColumns(
+          Seq(UnresolvedAttribute("state"), UnresolvedAttribute("country")),
+          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
+        val sortedPlan: LogicalPlan =
+          Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
+        val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)
+
+        val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
+        // Compare the two plans
+        comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+    }
+  }
+
+  // Columns added by eval (plain attribute aliases) can be dropped along with source columns.
+  test("fields minus new field added by eval") {
+    val frame = sql(s"""
+         | source = $testTable| eval national = country, newAge = age
+         | | fields - state, national, newAge | sort - age | head 1
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    assert(results.length == 1)
+    val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4))
+    assert(results.sameElements(expectedResults))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+    val evalProject = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(UnresolvedAttribute("country"), "national")(),
+        Alias(UnresolvedAttribute("age"), "newAge")()),
+      table)
+    val drop = DataFrameDropColumns(
+      Seq(
+        UnresolvedAttribute("state"),
+        UnresolvedAttribute("national"),
+        UnresolvedAttribute("newAge")),
+      evalProject)
+    val sortedPlan: LogicalPlan =
+      Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
+    val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)
+
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
+    // Compare the two plans
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  // Variant where the eval aliases wrap function expressions instead of plain attributes.
+  // TODO this test should work once the bug https://issues.apache.org/jira/browse/SPARK-49782 is fixed.
+  ignore("fields minus new function expression added by eval") {
+    val frame = sql(s"""
+         | source = $testTable| eval national = lower(country), newAge = age + 1
+         | | fields - state, national, newAge | sort - age | head 1
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    assert(results.length == 1)
+    val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4))
+    assert(results.sameElements(expectedResults))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+    val lowerFunction =
+      UnresolvedFunction("lower", Seq(UnresolvedAttribute("country")), isDistinct = false)
+    val addFunction =
+      UnresolvedFunction("+", Seq(UnresolvedAttribute("age"), Literal(1)), isDistinct = false)
+    val evalProject = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(lowerFunction, "national")(),
+        Alias(addFunction, "newAge")()),
+      table)
+    val drop = DataFrameDropColumns(
+      Seq(
+        UnresolvedAttribute("state"),
+        UnresolvedAttribute("national"),
+        UnresolvedAttribute("newAge")),
+      evalProject)
+    val sortedPlan: LogicalPlan =
+      Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
+    val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)
+
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
+    // Compare the two plans
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index 0c83abd97..e76d8d52d 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -230,6 +230,8 @@ See the next samples of PPL queries :
 **Fields**
  - `source = table`
  - `source = table | fields a,b,c`
+ - `source = table | fields + a,b,c`
+ - `source = table | fields - b,c`
 
 **Nested-Fields**
  - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1`
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 2aa99cd67..94d4132cd 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -265,20 +265,28 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
 
+    /**
+     * Builds the Catalyst plan for a PPL Project node. When node.isExcluded() is true the collected
+     * expressions become a DataFrameDropColumns step; otherwise they become a Catalyst Project and
+     * the project list is also recorded on the context before the child is visited.
+     */
     @Override
     public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
-        context.withProjectedFields(node.getProjectList());
+        if (!node.isExcluded()) {
+            context.withProjectedFields(node.getProjectList());
+        }
         LogicalPlan child = node.getChild().get(0).accept(this, context);
         visitExpressionList(node.getProjectList(), context);
 
         // Create a projection list from the existing expressions
         Seq projectList = seq(context.getNamedParseExpressions());
         if (!projectList.isEmpty()) {
-            Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
-            // build the plan with the projection step
-            child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
-        }
-        if (node.hasArgument()) {
-            Argument argument = node.getArgExprList().get(0);
-            //todo exclude the argument from the projected arguments list
+            if (node.isExcluded()) {
+                Seq dropList = context.retainAllNamedParseExpressions(p -> p);
+                // build the DataFrameDropColumns plan with drop list
+                child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns(dropList, p));
+            } else {
+                Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
+                // build the plan with the projection step
+                child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+            }
         }
         return child;
     }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index ad3ee18ba..a5deac0f0 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -280,4 +280,39 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
 
     comparePlans(expectedPlan, logPlan, false)
   }
+
+  // `fields +` after a sort: the expected plan projects A and B on top of the Sort node.
+  test("test fields + field list") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, "source=t | sort - A | fields + A, B | head 5", false),
+      context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
+    val sorted = Sort(sortOrder, true, table)
+    val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val projection = Project(projectList, sorted)
+
+    val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projection))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  
test("test fields - field list") {
+    val context = new CatalystPlanContext
+    val query = "source=t | sort - A | fields - A, B | head 5"
+    val logPlan = planTransformer.visit(plan(pplParser, query, false), context)
+
+    // Expected: the excluded fields A and B form a DataFrameDropColumns node over the sorted relation
+    val relation = UnresolvedRelation(Seq("t"))
+    val sortedRelation =
+      Sort(Seq(SortOrder(UnresolvedAttribute("A"), Descending)), global = true, relation)
+    val droppedColumns = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val dropPlan = DataFrameDropColumns(droppedColumns, sortedRelation)
+
+    // Expected: head 5 is a GlobalLimit over a LocalLimit, wrapped by the outer star projection
+    val limited = GlobalLimit(Literal(5), LocalLimit(Literal(5), dropPlan))
+    comparePlans(Project(Seq(UnresolvedStar(None)), limited), logPlan, false)
+  }
 }