From bc8e5375716b4a621b9f6982dd005a072da4137e Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 25 Sep 2024 21:10:32 +0800 Subject: [PATCH] Support Fields Minus Command (#698) * Support Fields Minus Command Signed-off-by: Lantao Jin * add limitation Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin (cherry picked from commit 919ba5ce4c2c7179c1f80508a4672ee685dbced3) --- .../spark/ppl/FlintSparkPPLBasicITSuite.scala | 130 +++++++++++++++++- ppl-spark-integration/README.md | 7 + .../sql/ppl/CatalystQueryPlanVisitor.java | 19 ++- ...lPlanBasicQueriesTranslatorTestSuite.scala | 34 +++++ 4 files changed, 181 insertions(+), 9 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index 7d51e123d..c56868a7e 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -7,8 +7,8 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedTableOrView} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command.DescribeTableCommand import org.apache.spark.sql.streaming.StreamTest @@ -260,4 +260,130 @@ class FlintSparkPPLBasicITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } } + + test("fields plus command") { + Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach { + case (selectFields, sortField) => + val frame = sql(s""" + | source = $testTable| fields + $selectFields | head 1 | sort $sortField + | """.stripMargin) + frame.show() + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + // Define the expected logical plan + val limitPlan: LogicalPlan = Limit(Literal(1), project) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("fields minus command") { + Seq(("state, country", "age"), ("`state`, `country`", "`age`")).foreach { + case (selectFields, sortField) => + val frame = sql(s""" + | source = $testTable| fields - $selectFields | sort - $sortField | head 1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + assert(results.length == 1) + val expectedResults: Array[Row] = Array(Row("Jake", 70, 2023, 4)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val drop = DataFrameDropColumns( + Seq(UnresolvedAttribute("state"), UnresolvedAttribute("country")), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop) + val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("fields minus new field added by eval") { + val frame = sql(s""" + | source = $testTable| eval national = country, newAge = age + | | fields - state, national, newAge | sort - age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("country"), "national")(), + Alias(UnresolvedAttribute("age"), "newAge")()), + table) + val drop = DataFrameDropColumns( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("national"), + UnresolvedAttribute("newAge")), + evalProject) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop) + val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // TODO this test should work when the bug https://issues.apache.org/jira/browse/SPARK-49782 fixed. + ignore("fields minus new function expression added by eval") { + val frame = sql(s""" + | source = $testTable| eval national = lower(country), newAge = age + 1 + | | fields - state, national, newAge | sort - age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val lowerFunction = + UnresolvedFunction("lower", Seq(UnresolvedAttribute("country")), isDistinct = false) + val addFunction = + UnresolvedFunction("+", Seq(UnresolvedAttribute("age"), Literal(1)), isDistinct = false) + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias(lowerFunction, "national")(), + Alias(addFunction, "newAge")()), + table) + val drop = DataFrameDropColumns( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("national"), + UnresolvedAttribute("newAge")), + evalProject) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop) + val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index fa668041d..b6bd3b9aa 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -230,6 +230,13 @@ See the next samples of PPL queries : **Fields** - `source = table` - `source = table | fields a,b,c` + - `source = table | fields + a,b,c` + - `source = table | fields - b,c` + - `source = table | eval b1 = b | fields - b1,c` + +_- **Limitation: new field added by eval command with a function cannot be dropped in current version:**_ + - `source = table | eval b1 = b + 1 | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) + - `source = table | eval b1 = lower(b) | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) **Nested-Fields** - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1` diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index e78be65f7..262ddcd48 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -230,19 +230,24 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { @Override public LogicalPlan visitProject(Project node, CatalystPlanContext context) { + if (!node.isExcluded()) { + context.withProjectedFields(node.getProjectList()); + } LogicalPlan child = node.getChild().get(0).accept(this, context); List expressionList = visitExpressionList(node.getProjectList(), context); // Create a projection list from the existing expressions Seq projectList = seq(context.getNamedParseExpressions()); if (!projectList.isEmpty()) { - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - } - if (node.hasArgument()) { - Argument argument = node.getArgExprList().get(0); - //todo exclude the argument from the projected arguments list + if (node.isExcluded()) { + Seq dropList = context.retainAllNamedParseExpressions(p -> p); + // build the DataFrameDropColumns plan with drop list + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns(dropList, p)); + } else { + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } } return child; } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 131778820..2068060b6 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -267,4 +267,38 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("test fields + field list") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t | sort - A | fields + A, B | head 5", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending)) + val sorted = Sort(sortOrder, true, table) + val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) + val projection = Project(projectList, sorted) + + val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projection)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, false) + } + + test("test fields - field list") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t | sort - A | fields - A, B | head 5", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending)) + val sorted = Sort(sortOrder, true, table) + val dropList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) + val dropAB = DataFrameDropColumns(dropList, sorted) + + val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), dropAB)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, false) + } }