Skip to content

Commit

Permalink
Merge branch 'main' into ppl-help-command
Browse files Browse the repository at this point in the history
  • Loading branch information
YANG-DB authored Oct 3, 2024
2 parents b8dce6d + deaa83e commit bc78c55
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLRenameITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Fully qualified name of the table every test in this suite reads from. */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  /** Shorthand for the unresolved relation that every expected plan starts from. */
  private def testRelation: UnresolvedRelation =
    UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Set up the partitioned state/country fixture once for the whole suite.
    createPartitionedStateCountryTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Shut down any streaming queries a test may have left running.
    spark.streams.active.foreach { query =>
      query.stop()
      query.awaitTermination()
    }
  }

  test("test renamed should remove source field") {
    val df = sql(s"""
         | source = $testTable | rename age as renamed_age
         | """.stripMargin)

    df.collect()
    // The new name must appear, the column count must be unchanged, and the
    // schema must match exactly — i.e. the original `age` column is gone.
    assert(df.columns.contains("renamed_age"))
    assert(df.columns.length == 6)
    assert(
      df.columns.sameElements(
        Array("name", "state", "country", "year", "month", "renamed_age")))
  }

  test("test renamed should report error when field is referenced by its original name") {
    // After the rename, the original column name must no longer resolve.
    val thrown = intercept[AnalysisException](sql(s"""
         | source = $testTable | rename age as renamed_age | fields age
         | """.stripMargin))
    assert(
      thrown
        .getMessage()
        .contains(" A column or function parameter with name `age` cannot be resolved"))
  }

  test("test single renamed field in fields command") {
    val df = sql(s"""
         | source = $testTable | rename age as renamed_age | fields name, renamed_age
         | """.stripMargin)

    // Check the returned data, ignoring row order.
    implicit val byFirstColumn: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    val actualRows: Array[Row] = df.collect()
    val expectedRows: Array[Row] =
      Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20))
    assert(actualRows.sorted.sameElements(expectedRows.sorted))

    // Check the unanalyzed plan: Project(*, age AS renamed_age) over the table,
    // then drop the original column, then the final `fields` projection.
    val renamed = Project(
      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "renamed_age")()),
      testRelation)
    val dropped = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), renamed)
    val expectedPlan = Project(
      Seq(UnresolvedAttribute("name"), UnresolvedAttribute("renamed_age")),
      dropped)
    comparePlans(df.queryExecution.logical, expectedPlan, checkAnalysis = false)
  }

  test("test multiple renamed fields in fields command") {
    val df = sql(s"""
         | source = $testTable | rename name as renamed_name, country as renamed_country | fields renamed_name, age, renamed_country
         | """.stripMargin)

    implicit val byFirstColumn: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    val actualRows: Array[Row] = df.collect()
    val expectedRows: Array[Row] = Array(
      Row("Jake", 70, "USA"),
      Row("Hello", 30, "USA"),
      Row("John", 25, "Canada"),
      Row("Jane", 20, "Canada"))
    assert(actualRows.sorted.sameElements(expectedRows.sorted))

    // Both renames become star + aliases in one Project, and both original
    // columns are removed in one DataFrameDropColumns step.
    val renamed = Project(
      Seq(
        UnresolvedStar(None),
        Alias(UnresolvedAttribute("name"), "renamed_name")(),
        Alias(UnresolvedAttribute("country"), "renamed_country")()),
      testRelation)
    val dropped = DataFrameDropColumns(
      Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
      renamed)
    val expectedPlan = Project(
      Seq(
        UnresolvedAttribute("renamed_name"),
        UnresolvedAttribute("age"),
        UnresolvedAttribute("renamed_country")),
      dropped)
    comparePlans(df.queryExecution.logical, expectedPlan, checkAnalysis = false)
  }

  test("test renamed fields without fields command") {
    val df = sql(s"""
         | source = $testTable | rename age as user_age, country as user_country
         | """.stripMargin)

    implicit val byFirstColumn: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    val actualRows: Array[Row] = df.collect()
    val expectedRows: Array[Row] = Array(
      Row("Jake", "California", 2023, 4, 70, "USA"),
      Row("Hello", "New York", 2023, 4, 30, "USA"),
      Row("John", "Ontario", 2023, 4, 25, "Canada"),
      Row("Jane", "Quebec", 2023, 4, 20, "Canada"))
    assert(actualRows.sorted.sameElements(expectedRows.sorted))

    // With no trailing `fields` command the final projection is a bare star.
    val renamed = Project(
      Seq(
        UnresolvedStar(None),
        Alias(UnresolvedAttribute("age"), "user_age")(),
        Alias(UnresolvedAttribute("country"), "user_country")()),
      testRelation)
    val dropped = DataFrameDropColumns(
      Seq(UnresolvedAttribute("age"), UnresolvedAttribute("country")),
      renamed)
    val expectedPlan = Project(seq(UnresolvedStar(None)), dropped)
    comparePlans(df.queryExecution.logical, expectedPlan, checkAnalysis = false)
  }

  test("test renamed field used in aggregation") {
    val df = sql(s"""
         | source = $testTable | rename age as user_age | stats avg(user_age) by country
         | """.stripMargin)

    implicit val byAverage: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
    val actualRows: Array[Row] = df.collect()
    val expectedRows: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
    assert(actualRows.sorted.sameElements(expectedRows.sorted))

    // The aggregation sits on top of the rename (star + alias, drop original),
    // and the renamed attribute is what feeds AVG.
    val renamed = Project(
      Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")()),
      testRelation)
    val dropped = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), renamed)
    val aggregatePlan = Aggregate(
      Seq(Alias(UnresolvedAttribute("country"), "country")()),
      Seq(
        Alias(
          UnresolvedFunction(
            Seq("AVG"),
            Seq(UnresolvedAttribute("user_age")),
            isDistinct = false),
          "avg(user_age)")(),
        Alias(UnresolvedAttribute("country"), "country")()),
      dropped)
    val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
    comparePlans(df.queryExecution.logical, expectedPlan, checkAnalysis = false)
  }
}
4 changes: 4 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | fields message, no_numbers`
- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | stats count() by no_numbers`

**Rename**
- `source=accounts | rename email as user_email | fields id, user_email`
- `source=accounts | rename id as user_id, email as user_email | fields user_id, user_email`

_- **Limitation: Overriding existing field is unsupported:**_
- `source=accounts | grok address '%{NUMBER} %{GREEDYDATA:address}' | fields address`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ commands
| parseCommand
| patternsCommand
| lookupCommand
| renameCommand
;

searchCommand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.List;

Expand All @@ -20,10 +20,10 @@
@Getter
@RequiredArgsConstructor
public class Rename extends UnresolvedPlan {
private final List<Map> renameList;
private final List<UnresolvedExpression> renameList;
private UnresolvedPlan child;

public Rename(List<Map> renameList, UnresolvedPlan child) {
public Rename(List<UnresolvedExpression> renameList, UnresolvedPlan child) {
this.renameList = renameList;
this.child = child;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.catalyst.plans.logical.*;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
Expand Down Expand Up @@ -77,6 +72,7 @@
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TopAggregation;
Expand Down Expand Up @@ -422,6 +418,23 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
return ParseStrategy.visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context);
}

/**
 * Lowers a PPL {@code rename} command into Catalyst operators.
 * <p>
 * Visible strategy: project all existing columns plus one alias per rename
 * item (e.g. {@code age AS renamed_age}), then remove the original source
 * columns, yielding
 * {@code DataFrameDropColumns(originals, Project(star ++ aliases, child))}.
 */
@Override
public LogicalPlan visitRename(Rename node, CatalystPlanContext context) {
    // Lower the child first so the context carries its plan (and any
    // expressions the child pushed) before we build the rename projection.
    node.getChild().get(0).accept(this, context);
    if (context.getNamedParseExpressions().isEmpty()) {
        // Create an UnresolvedStar for all-fields projection
        context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
    }
    // Each rename item lowers to Alias(original, newName); the alias child is
    // the original attribute that must be dropped after projecting.
    // NOTE(review): the cast assumes every rename entry is an Alias — a
    // non-Alias expression here would throw ClassCastException; presumably the
    // AST builder only emits Alias nodes for rename — confirm.
    List<Expression> fieldsToRemove = visitExpressionList(node.getRenameList(), context).stream()
            .map(expression -> (org.apache.spark.sql.catalyst.expressions.Alias) expression)
            .map(org.apache.spark.sql.catalyst.expressions.Alias::child)
            .collect(Collectors.toList());
    // Everything accumulated so far (star + rename aliases) becomes the projection list.
    Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
    // build the plan with the projection step
    LogicalPlan outputWithSourceColumns = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
    // Finally strip the pre-rename columns so only the new names remain.
    return context.apply(p -> DataFrameDropColumns$.MODULE$.apply(seq(fieldsToRemove), outputWithSourceColumns));
}

@Override
public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Scope;
Expand Down Expand Up @@ -223,9 +222,9 @@ public UnresolvedPlan visitRenameClause(OpenSearchPPLParser.RenameClauseContext
ctx.renameClasue().stream()
.map(
ct ->
new Map(
internalVisitExpression(ct.orignalField),
internalVisitExpression(ct.renamedField)))
new Alias(
ct.renamedField.getText(),
internalVisitExpression(ct.orignalField)))
.collect(Collectors.toList()));
}

Expand Down
Loading

0 comments on commit bc78c55

Please sign in to comment.