diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala new file mode 100644 index 000000000..9859a552e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLRenameITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test renamed should remove source field") { + val frame = sql(s""" + | source = $testTable | rename age as renamed_age + | """.stripMargin) + + // Retrieve the results + frame.collect() + assert(frame.columns.contains("renamed_age")) + assert(frame.columns.length == 6) + val expectedColumns = + Array[String]("name", "state", "country", "year", "month", "renamed_age") + assert(frame.columns.sameElements(expectedColumns)) + } + + test("test renamed should report error when field is referenced by its original name") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | rename age as renamed_age | fields age + | """.stripMargin)) + assert( + ex.getMessage() + .contains(" A column or function parameter with name `age` cannot be resolved")) + } + + test("test single renamed field in fields command") { + val frame = sql(s""" + | source = $testTable | rename age as renamed_age | fields name, renamed_age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("renamed_age")) + val renameProjectList = + Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "renamed_age")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject) + val expectedPlan = Project(fieldsProjectList, planDropColumn) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple renamed fields in fields command") { + val frame = sql(s""" + | source = $testTable | rename name as renamed_name, country as renamed_country | fields renamed_name, age, renamed_country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "USA"), + Row("Hello", 30, "USA"), + Row("John", 25, "Canada"), + Row("Jane", 20, "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("renamed_name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("renamed_country")) + val renameProjectList = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("name"), "renamed_name")(), + Alias(UnresolvedAttribute("country"), "renamed_country")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")), + innerProject) + val expectedPlan = Project(fieldsProjectList, planDropColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test renamed fields without fields command") { + val frame = sql(s""" + | source = $testTable | rename age as user_age, country as user_country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "California", 2023, 4, 70, "USA"), + Row("Hello", "New York", 2023, 4, 30, "USA"), + Row("John", "Ontario", 2023, 4, 25, "Canada"), + Row("Jane", "Quebec", 2023, 4, 20, "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + val sorted = results.sorted + val expectedSorted = expectedResults.sorted + assert(sorted.sameElements(expectedSorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val renameProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("age"), "user_age")(), + Alias(UnresolvedAttribute("country"), "user_country")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns( + Seq(UnresolvedAttribute("age"), UnresolvedAttribute("country")), + innerProject) + val expectedPlan = Project(seq(UnresolvedStar(None)), planDropColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test renamed field used in aggregation") { + val frame = sql(s""" + | source = $testTable | rename age as user_age | stats avg(user_age) by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val renameProjectList = + Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")()) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("AVG"), + Seq(UnresolvedAttribute("user_age")), + isDistinct = false), + "avg(user_age)")(), + Alias(UnresolvedAttribute("country"), "country")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject) + val aggregatePlan = Aggregate( + Seq(Alias(UnresolvedAttribute("country"), "country")()), + aggregateExpressions, + planDropColumn) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 02baaab45..104c193a6 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -404,6 +404,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | fields message, no_numbers` - `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | stats count() by no_numbers` +**Rename** +- `source=accounts | rename email as user_email | fields id, user_email` +- `source=accounts | rename id as user_id, email as user_email | fields user_id, user_email` + _- **Limitation: Overriding existing field is unsupported:**_ - `source=accounts | grok address '%{NUMBER} %{GREEDYDATA:address}' | fields address` diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 30b57d5da..b4393d495 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -46,6 +46,7 @@ commands | parseCommand | patternsCommand | lookupCommand + | renameCommand ; searchCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java index 2cc2a3e73..04805ad9d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java @@ -11,7 +11,7 @@ import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; @@ -20,10 +20,10 @@ @Getter @RequiredArgsConstructor public class Rename extends UnresolvedPlan { - private final List renameList; + private final List renameList; private UnresolvedPlan child; - public Rename(List renameList, UnresolvedPlan child) { + public Rename(List renameList, UnresolvedPlan child) { this.renameList = renameList; this.child = child; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 9e00025ea..e9b9f15ba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -17,12 +17,7 @@ import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.Aggregate; -import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; -import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; -import org.apache.spark.sql.catalyst.plans.logical.Limit; -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; -import org.apache.spark.sql.catalyst.plans.logical.Project$; +import org.apache.spark.sql.catalyst.plans.logical.*; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; @@ -72,6 +67,7 @@ import org.opensearch.sql.ast.tree.RareAggregation; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; @@ -399,6 +395,23 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } + @Override + public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + List fieldsToRemove = visitExpressionList(node.getRenameList(), context).stream() + .map(expression -> (org.apache.spark.sql.catalyst.expressions.Alias) expression) + .map(org.apache.spark.sql.catalyst.expressions.Alias::child) + .collect(Collectors.toList()); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + LogicalPlan outputWithSourceColumns = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + return context.apply(p -> DataFrameDropColumns$.MODULE$.apply(seq(fieldsToRemove), outputWithSourceColumns)); + } + @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e9aee3180..763cc5701 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -22,7 +22,6 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Scope; @@ -211,9 +210,9 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex ctx.renameClasue().stream() .map( ct -> - new Map( - internalVisitExpression(ct.orignalField), - internalVisitExpression(ct.renamedField))) + new Alias( + ct.renamedField.getText(), + internalVisitExpression(ct.orignalField))) .collect(Collectors.toList())); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala new file mode 100644 index 000000000..e02c5b2c4 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project, Sort} + +class PPLLogicalPlanRenameTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test renamed fields not included in fields expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"), + context) + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "r_a")(), + Alias(UnresolvedAttribute("b"), "r_b")()) + val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t"))) + val planDropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject) + val expectedPlan = Project(seq(UnresolvedAttribute("c")), planDropColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test renamed fields included in fields expression") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | rename a as r_a, b as r_b | fields r_a, r_b, c"), + context) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "r_a")(), + Alias(UnresolvedAttribute("b"), "r_b")()) + val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t"))) + val planDropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject) + val expectedPlan = Project( + seq(UnresolvedAttribute("r_a"), UnresolvedAttribute("r_b"), UnresolvedAttribute("c")), + planDropColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test renamed fields without fields command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | rename a as r_a, b as r_b"), context) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "r_a")(), + Alias(UnresolvedAttribute("b"), "r_b")()) + val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t"))) + val planDropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject) + val expectedPlan = + Project(seq(UnresolvedStar(None)), planDropColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test renamed fields with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | rename a as r_a, b as r_b | sort - r_a | fields r_b"), + context) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "r_a")(), + Alias(UnresolvedAttribute("b"), "r_b")()) + val renameProject = Project(renameProjectList, UnresolvedRelation(Seq("t"))) + val planDropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), renameProject) + val sortOrder = SortOrder(UnresolvedAttribute("r_a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, planDropColumn) + val expectedPlan = Project(seq(UnresolvedAttribute("r_b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test rename eval expression output") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = RAND() | rename a as eval_rand | fields eval_rand"), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")( + exprId = ExprId(0), + qualifier = Seq.empty)) + val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val renameProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("a"), "eval_rand")()) + val innerProject = Project(renameProjectList, evalProject) + val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("a")), innerProject) + val expectedPlan = + Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}