Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rename PPL function #618

Merged
merged 7 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest

/**
 * Integration tests for the PPL `rename` command.
 *
 * `rename a as b` is expected to translate into a Catalyst plan of the shape
 * `Project [*, a AS b]` followed by `DataFrameDropColumns([a])`, so the source
 * field is no longer addressable by its original name after the rename.
 */
class FlintSparkPPLRenameITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedStateCountryTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  test("test renamed should remove source field") {
    val frame = sql(s"""
         | source = $testTable | rename age as renamed_age
         | """.stripMargin)

    // Retrieve the results
    frame.collect()
    // The renamed column replaces the source column; total column count is unchanged.
    assert(frame.columns.contains("renamed_age"))
    assert(frame.columns.length == 6)
    val expectedColumns =
      Array[String]("name", "state", "country", "year", "month", "renamed_age")
    assert(frame.columns.sameElements(expectedColumns))
  }

  test("test renamed should report error when field is referenced by its original name") {
    // After the rename, the original name must be unresolvable downstream.
    val ex = intercept[AnalysisException](sql(s"""
         | source = $testTable | rename age as renamed_age | fields age
         | """.stripMargin))
    assert(
      ex.getMessage()
        .contains(" A column or function parameter with name `age` cannot be resolved"))
  }

  test("test single renamed field in fields command") {
    val frame = sql(s"""
         | source = $testTable | rename age as renamed_age | fields name, renamed_age
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results
    val expectedResults: Array[Row] =
      Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20))
    // Compare the results
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan: Project(*, age AS renamed_age) -> drop(age) -> fields projection
    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
    val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("renamed_age"))
    val renameProjectList =
      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "renamed_age")())
    val innerProject = Project(renameProjectList, table)
    val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
    val expectedPlan = Project(fieldsProjectList, planDropColumn)
    // Compare the two plans
    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
  }

  test("test multiple renamed fields in fields command") {
    val frame = sql(s"""
         | source = $testTable | rename name as renamed_name, country as renamed_country | fields renamed_name, age, renamed_country
         | """.stripMargin)

    val results: Array[Row] = frame.collect()
    val expectedResults: Array[Row] =
      Array(
        Row("Jake", 70, "USA"),
        Row("Hello", 30, "USA"),
        Row("John", 25, "Canada"),
        Row("Jane", 20, "Canada"))
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Both source columns (name, country) must be dropped after aliasing.
    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
    val fieldsProjectList = Seq(
      UnresolvedAttribute("renamed_name"),
      UnresolvedAttribute("age"),
      UnresolvedAttribute("renamed_country"))
    val renameProjectList =
      Seq(
        UnresolvedStar(None),
        Alias(UnresolvedAttribute("name"), "renamed_name")(),
        Alias(UnresolvedAttribute("country"), "renamed_country")())
    val innerProject = Project(renameProjectList, table)
    val planDropColumn = DataFrameDropColumns(
      Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
      innerProject)
    val expectedPlan = Project(fieldsProjectList, planDropColumn)
    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
  }

  test("test renamed fields without fields command") {
    val frame = sql(s"""
         | source = $testTable | rename age as user_age, country as user_country
         | """.stripMargin)

    val results: Array[Row] = frame.collect()
    val expectedResults: Array[Row] = Array(
      Row("Jake", "California", 2023, 4, 70, "USA"),
      Row("Hello", "New York", 2023, 4, 30, "USA"),
      Row("John", "Ontario", 2023, 4, 25, "Canada"),
      Row("Jane", "Quebec", 2023, 4, 20, "Canada"))
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    val sorted = results.sorted
    val expectedSorted = expectedResults.sorted
    assert(sorted.sameElements(expectedSorted))

    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // With no trailing fields command, the final projection is a bare star.
    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
    val renameProjectList = Seq(
      UnresolvedStar(None),
      Alias(UnresolvedAttribute("age"), "user_age")(),
      Alias(UnresolvedAttribute("country"), "user_country")())
    val innerProject = Project(renameProjectList, table)
    val planDropColumn = DataFrameDropColumns(
      Seq(UnresolvedAttribute("age"), UnresolvedAttribute("country")),
      innerProject)
    val expectedPlan = Project(seq(UnresolvedStar(None)), planDropColumn)
    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
  }

  test("test renamed field used in aggregation") {
    val frame = sql(s"""
         | source = $testTable | rename age as user_age | stats avg(user_age) by country
         | """.stripMargin)

    val results: Array[Row] = frame.collect()
    val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // The aggregation must reference the new name; the rename subtree sits below it.
    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
    val renameProjectList =
      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
    val aggregateExpressions =
      Seq(
        Alias(
          UnresolvedFunction(
            Seq("AVG"),
            Seq(UnresolvedAttribute("user_age")),
            isDistinct = false),
          "avg(user_age)")(),
        Alias(UnresolvedAttribute("country"), "country")())
    val innerProject = Project(renameProjectList, table)
    val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
    val aggregatePlan = Aggregate(
      Seq(Alias(UnresolvedAttribute("country"), "country")()),
      aggregateExpressions,
      planDropColumn)
    val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
  }
}
4 changes: 4 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | fields message, no_numbers`
- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | stats count() by no_numbers`

**Rename**
- `source=accounts | rename email as user_email | fields id, user_email`
- `source=accounts | rename id as user_id, email as user_email | fields user_id, user_email`

_- **Limitation: Overriding existing field is unsupported:**_
- `source=accounts | grok address '%{NUMBER} %{GREEDYDATA:address}' | fields address`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ commands
| parseCommand
| patternsCommand
| lookupCommand
| renameCommand
;

searchCommand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.List;

Expand All @@ -20,10 +20,10 @@
@Getter
@RequiredArgsConstructor
public class Rename extends UnresolvedPlan {
private final List<Map> renameList;
private final List<UnresolvedExpression> renameList;
private UnresolvedPlan child;

public Rename(List<Map> renameList, UnresolvedPlan child) {
public Rename(List<UnresolvedExpression> renameList, UnresolvedPlan child) {
this.renameList = renameList;
this.child = child;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.catalyst.plans.logical.*;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
Expand Down Expand Up @@ -72,6 +67,7 @@
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TopAggregation;
Expand Down Expand Up @@ -399,6 +395,23 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context);
}

/**
 * Translates a PPL {@code rename} command into a Catalyst plan fragment.
 * <p>
 * The rename list arrives as Alias expressions (new name aliasing the original
 * attribute). The method builds a Project over the child that carries every
 * existing column (star, pushed below when the expression stack is empty) plus
 * one Alias per rename, then wraps it in DataFrameDropColumns over the Alias
 * children (the original attributes) so the old names are removed.
 * <p>
 * NOTE(review): relies on CatalystPlanContext's named-parse-expression stack
 * being populated by the child visit — assumes no unrelated expressions remain
 * on the stack at this point; verify against other visitors.
 */
@Override
public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
// Visit the child first so its plan is on the context and its expressions are staged.
node.getChild().get(0).accept(this, context);
if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
}
// Each rename is an Alias(original, newName); the Alias child is the original
// attribute that must be dropped after projection.
List<Expression> fieldsToRemove = visitExpressionList(node.getRenameList(), context).stream()
.map(expression -> (org.apache.spark.sql.catalyst.expressions.Alias) expression)
.map(org.apache.spark.sql.catalyst.expressions.Alias::child)
.collect(Collectors.toList());
// Drain the staged expressions (star + aliases) into the projection list.
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
// build the plan with the projection step
LogicalPlan outputWithSourceColumns = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
return context.apply(p -> DataFrameDropColumns$.MODULE$.apply(seq(fieldsToRemove), outputWithSourceColumns));
}

@Override
public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Scope;
Expand Down Expand Up @@ -211,9 +210,9 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex
ctx.renameClasue().stream()
.map(
ct ->
new Map(
internalVisitExpression(ct.orignalField),
internalVisitExpression(ct.renamedField)))
new Alias(
ct.renamedField.getText(),
internalVisitExpression(ct.orignalField)))
.collect(Collectors.toList()));
}

Expand Down
Loading
Loading