Skip to content

Commit

Permalink
The rename command does not return the source column/field
Browse files Browse the repository at this point in the history
Signed-off-by: Lukasz Soszynski <[email protected]>
  • Loading branch information
lukasz-soszynski-eliatra committed Sep 30, 2024
1 parent cddb658 commit 1420ab3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ package org.opensearch.flint.spark.ppl

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLRenameITSuite
Expand Down Expand Up @@ -38,6 +38,29 @@ class FlintSparkPPLRenameITSuite
}
}

test("test renamed should remove source field") {
val frame = sql(s"""
| source = $testTable | rename age as renamed_age
| """.stripMargin)

// Retrieve the results
frame.collect()
assert(frame.columns.contains("renamed_age"))
assert(frame.columns.length == 6)
val expectedColumns =
Array[String]("name", "state", "country", "year", "month", "renamed_age")
assert(frame.columns.sameElements(expectedColumns))
}

test("test renamed should report error when field is referenced by its original name") {
val ex = intercept[AnalysisException](sql(s"""
| source = $testTable | rename age as renamed_age | fields age
| """.stripMargin))
assert(
ex.getMessage()
.contains(" A column or function parameter with name `age` cannot be resolved"))
}

test("test single renamed field in fields command") {
val frame = sql(s"""
| source = $testTable | rename age as renamed_age | fields name, renamed_age
Expand All @@ -59,7 +82,9 @@ class FlintSparkPPLRenameITSuite
val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("renamed_age"))
val renameProjectList =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "renamed_age")())
val expectedPlan = Project(fieldsProjectList, Project(renameProjectList, table))
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
val expectedPlan = Project(fieldsProjectList, planDropColumn)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand Down Expand Up @@ -91,7 +116,11 @@ class FlintSparkPPLRenameITSuite
UnresolvedStar(None),
Alias(UnresolvedAttribute("name"), "renamed_name")(),
Alias(UnresolvedAttribute("country"), "renamed_country")())
val expectedPlan = Project(fieldsProjectList, Project(renameProjectList, table))
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
innerProject)
val expectedPlan = Project(fieldsProjectList, planDropColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

Expand All @@ -102,20 +131,26 @@ class FlintSparkPPLRenameITSuite

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4, 70, "USA"),
Row("Hello", 30, "New York", "USA", 2023, 4, 30, "USA"),
Row("John", 25, "Ontario", "Canada", 2023, 4, 25, "Canada"),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, "Canada"))
Row("Jake", "California", 2023, 4, 70, "USA"),
Row("Hello", "New York", 2023, 4, 30, "USA"),
Row("John", "Ontario", 2023, 4, 25, "Canada"),
Row("Jane", "Quebec", 2023, 4, 20, "Canada"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
val sorted = results.sorted
val expectedSorted = expectedResults.sorted
assert(sorted.sameElements(expectedSorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val renameProjectList = Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("age"), "user_age")(),
Alias(UnresolvedAttribute("country"), "user_country")())
val expectedPlan = Project(seq(UnresolvedStar(None)), Project(renameProjectList, table))
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(
Seq(UnresolvedAttribute("age"), UnresolvedAttribute("country")),
innerProject)
val expectedPlan = Project(seq(UnresolvedStar(None)), planDropColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

Expand All @@ -142,10 +177,12 @@ class FlintSparkPPLRenameITSuite
isDistinct = false),
"avg(user_age)")(),
Alias(UnresolvedAttribute("country"), "country")())
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
val aggregatePlan = Aggregate(
Seq(Alias(UnresolvedAttribute("country"), "country")()),
aggregateExpressions,
Project(renameProjectList, table))
planDropColumn)
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.catalyst.plans.logical.*;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
Expand Down Expand Up @@ -407,10 +402,14 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
}
visitExpressionList(node.getRenameList(), context);
List<Expression> fieldsToRemove = visitExpressionList(node.getRenameList(), context).stream()
.map(expression -> (org.apache.spark.sql.catalyst.expressions.Alias) expression)
.map(org.apache.spark.sql.catalyst.expressions.Alias::child)
.collect(Collectors.toList());
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
// build the plan with the projection step
return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
LogicalPlan outputWithSourceColumns = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
return context.apply(p -> DataFrameDropColumns$.MODULE$.apply(seq(fieldsToRemove), outputWithSourceColumns));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project, Sort}

class PPLLogicalPlanRenameTranslatorTestSuite
extends SparkFunSuite
Expand All @@ -36,9 +36,10 @@ class PPLLogicalPlanRenameTranslatorTestSuite
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val expectedPlan = Project(
seq(UnresolvedAttribute("c")),
Project(renameProjectList, UnresolvedRelation(Seq("t"))))
val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
val planDropColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject)
val expectedPlan = Project(seq(UnresolvedAttribute("c")), planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

Expand All @@ -54,9 +55,12 @@ class PPLLogicalPlanRenameTranslatorTestSuite
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
val planDropColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject)
val expectedPlan = Project(
seq(UnresolvedAttribute("r_a"), UnresolvedAttribute("r_b"), UnresolvedAttribute("c")),
Project(renameProjectList, UnresolvedRelation(Seq("t"))))
planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

Expand All @@ -70,8 +74,11 @@ class PPLLogicalPlanRenameTranslatorTestSuite
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
val planDropColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject)
val expectedPlan =
Project(seq(UnresolvedStar(None)), Project(renameProjectList, UnresolvedRelation(Seq("t"))))
Project(seq(UnresolvedStar(None)), planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

Expand All @@ -88,8 +95,10 @@ class PPLLogicalPlanRenameTranslatorTestSuite
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val renameProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
val planDropColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), renameProject)
val sortOrder = SortOrder(UnresolvedAttribute("r_a"), Descending, Seq.empty)
val sort = Sort(seq(sortOrder), global = true, renameProject)
val sort = Sort(seq(sortOrder), global = true, planDropColumn)
val expectedPlan = Project(seq(UnresolvedAttribute("r_b")), sort)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
Expand All @@ -109,8 +118,10 @@ class PPLLogicalPlanRenameTranslatorTestSuite
val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t")))
val renameProjectList: Seq[NamedExpression] =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("a"), "eval_rand")())
val innerProject = Project(renameProjectList, evalProject)
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("a")), innerProject)
val expectedPlan =
Project(seq(UnresolvedAttribute("eval_rand")), Project(renameProjectList, evalProject))
Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}

0 comments on commit 1420ab3

Please sign in to comment.