# Patch summary (leading text before the first "diff --git" header; not part of any hunk).
#
# 1. FlintSparkPPLRenameITSuite.scala (modified)
#    - Narrows three Spark import lines to QueryTest/Row, Alias, and
#      Aggregate/LogicalPlan/Project.
#    - "test renamed fields without fields command" now renames age -> user_age and
#      country -> user_country (previously state -> _state, country -> _country). The last
#      two values of each expected Row change accordingly, e.g. (..., 70, "USA") for "Jake".
#    - Adds "test renamed field used in aggregation": runs
#      `rename age as user_age | stats avg(user_age) by country`, expects
#      Row(22.5, "Canada") and Row(50.0, "USA"), and expects the logical plan
#      Project(*) over Aggregate(by country) over the rename Project(*, age AS user_age).
#
# 2. PPLLogicalPlanRenameTranslatorTestSuite.scala (new file, 121 lines)
#    Five tests. Each passes a parsed PPL query to CatalystQueryPlanVisitor.visit and compares
#    the result with a hand-built plan via comparePlans(..., checkAnalysis = false):
#      - rename followed by `fields c` (renamed columns not selected)
#      - rename followed by `fields r_a, r_b, c` (renamed columns selected)
#      - rename with no fields command (outer Project is UnresolvedStar)
#      - rename followed by `sort - r_a | fields r_b` (Descending Sort above the rename Project)
#      - `eval a = RAND()` followed by `rename a as eval_rand | fields eval_rand`
#    In every expected plan the rename step is Project(UnresolvedStar(None) + one Alias per
#    renamed field).
#
# NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line breaks were
# restored so that each hunk matches the counts in its "@@" header. Blank context lines and
# leading indentation (including the whitespace in front of the "|" margin characters inside
# the s"""...""" queries) are reconstructed - confirm against the original commit before
# applying.
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala
index 1a0b490f7..f9437afae 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala
@@ -7,10 +7,10 @@ package org.opensearch.flint.spark.ppl
 
 import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
 
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.streaming.StreamTest
 
 class FlintSparkPPLRenameITSuite
@@ -97,15 +97,15 @@ class FlintSparkPPLRenameITSuite
 
   test("test renamed fields without fields command") {
     val frame = sql(s"""
-         | source = $testTable | rename state as _state, country as _country
+         | source = $testTable | rename age as user_age, country as user_country
          | """.stripMargin)
 
     val results: Array[Row] = frame.collect()
     val expectedResults: Array[Row] = Array(
-      Row("Jake", 70, "California", "USA", 2023, 4, "California", "USA"),
-      Row("Hello", 30, "New York", "USA", 2023, 4, "New York", "USA"),
-      Row("John", 25, "Ontario", "Canada", 2023, 4, "Ontario", "Canada"),
-      Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Quebec", "Canada"))
+      Row("Jake", 70, "California", "USA", 2023, 4, 70, "USA"),
+      Row("Hello", 30, "New York", "USA", 2023, 4, 30, "USA"),
+      Row("John", 25, "Ontario", "Canada", 2023, 4, 25, "Canada"),
+      Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, "Canada"))
     implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
     assert(results.sorted.sameElements(expectedResults.sorted))
 
@@ -113,9 +113,40 @@ class FlintSparkPPLRenameITSuite
     val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
     val renameProjectList = Seq(
       UnresolvedStar(None),
-      Alias(UnresolvedAttribute("state"), "_state")(),
-      Alias(UnresolvedAttribute("country"), "_country")())
+      Alias(UnresolvedAttribute("age"), "user_age")(),
+      Alias(UnresolvedAttribute("country"), "user_country")())
     val expectedPlan = Project(seq(UnresolvedStar(None)), Project(renameProjectList, table))
     comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
   }
+
+  test("test renamed field used in aggregation") {
+    val frame = sql(s"""
+         | source = $testTable | rename age as user_age | stats avg(user_age) by country
+         | """.stripMargin)
+
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+    val renameProjectList =
+      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
+    val aggregateExpressions =
+      Seq(
+        Alias(
+          UnresolvedFunction(
+            Seq("AVG"),
+            Seq(UnresolvedAttribute("user_age")),
+            isDistinct = false),
+          "avg(user_age)")(),
+        Alias(UnresolvedAttribute("country"), "country")())
+    val aggregatePlan = Aggregate(
+      Seq(Alias(UnresolvedAttribute("country"), "country")()),
+      aggregateExpressions,
+      Project(renameProjectList, table))
+    val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala
new file mode 100644
index 000000000..02781ff88
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
+
+class PPLLogicalPlanRenameTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test renamed fields not included in fields expressions") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c", false),
+        context)
+    val renameProjectList: Seq[NamedExpression] =
+      Seq(
+        UnresolvedStar(None),
+        Alias(UnresolvedAttribute("a"), "r_a")(),
+        Alias(UnresolvedAttribute("b"), "r_b")())
+    val expectedPlan = Project(
+      seq(UnresolvedAttribute("c")),
+      Project(renameProjectList, UnresolvedRelation(Seq("t"))))
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test renamed fields included in fields expression") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | rename a as r_a, b as r_b | fields r_a, r_b, c", false),
+        context)
+
+    val renameProjectList: Seq[NamedExpression] =
+      Seq(
+        UnresolvedStar(None),
+        Alias(UnresolvedAttribute("a"), "r_a")(),
+        Alias(UnresolvedAttribute("b"), "r_b")())
+    val expectedPlan = Project(
+      seq(UnresolvedAttribute("r_a"), UnresolvedAttribute("r_b"), UnresolvedAttribute("c")),
+      Project(renameProjectList, UnresolvedRelation(Seq("t"))))
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test renamed fields without fields command") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | rename a as r_a, b as r_b", false),
+        context)
+
+    val renameProjectList: Seq[NamedExpression] =
+      Seq(
+        UnresolvedStar(None),
+        Alias(UnresolvedAttribute("a"), "r_a")(),
+        Alias(UnresolvedAttribute("b"), "r_b")())
+    val expectedPlan =
+      Project(seq(UnresolvedStar(None)), Project(renameProjectList, UnresolvedRelation(Seq("t"))))
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test renamed fields with sort") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, "source=t | rename a as r_a, b as r_b | sort - r_a | fields r_b", false),
+        context)
+
+    val renameProjectList: Seq[NamedExpression] =
+      Seq(
+        UnresolvedStar(None),
+        Alias(UnresolvedAttribute("a"), "r_a")(),
+        Alias(UnresolvedAttribute("b"), "r_b")())
+    val renameProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
+    val sortOrder = SortOrder(UnresolvedAttribute("r_a"), Descending, Seq.empty)
+    val sort = Sort(seq(sortOrder), global = true, renameProject)
+    val expectedPlan = Project(seq(UnresolvedAttribute("r_b")), sort)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  test("test rename eval expression output") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=t | eval a = RAND() | rename a as eval_rand | fields eval_rand",
+          false),
+        context)
+
+    val evalProjectList: Seq[NamedExpression] = Seq(
+      UnresolvedStar(None),
+      Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")(
+        exprId = ExprId(0),
+        qualifier = Seq.empty))
+    val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t")))
+    val renameProjectList: Seq[NamedExpression] =
+      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("a"), "eval_rand")())
+    val expectedPlan =
+      Project(seq(UnresolvedAttribute("eval_rand")), Project(renameProjectList, evalProject))
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+}