Skip to content

Commit

Permalink
Support Fields Minus Command
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Sep 25, 2024
1 parent f742a65 commit f78358d
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedTableOrView}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest
Expand Down Expand Up @@ -295,4 +295,130 @@ class FlintSparkPPLBasicITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}

test("fields plus command") {
Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach {
case (selectFields, sortField) =>
val frame = sql(s"""
| source = $testTable| fields + $selectFields | head 1 | sort $sortField
| """.stripMargin)
frame.show()
val results: Array[Row] = frame.collect()
assert(results.length == 1)

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val project = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Define the expected logical plan
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}

test("fields minus command") {
Seq(("state, country", "age"), ("`state`, `country`", "`age`")).foreach {
case (selectFields, sortField) =>
val frame = sql(s"""
| source = $testTable| fields - $selectFields | sort - $sortField | head 1
| """.stripMargin)

val results: Array[Row] = frame.collect()
assert(results.length == 1)
val expectedResults: Array[Row] = Array(Row("Jake", 70, 2023, 4))
assert(results.sameElements(expectedResults))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val drop = DataFrameDropColumns(
Seq(UnresolvedAttribute("state"), UnresolvedAttribute("country")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}

test("fields minus new field added by eval") {
val frame = sql(s"""
| source = $testTable| eval national = country, newAge = age
| | fields - state, national, newAge | sort - age | head 1
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)
val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4))
assert(results.sameElements(expectedResults))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val evalProject = Project(
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("country"), "national")(),
Alias(UnresolvedAttribute("age"), "newAge")()),
table)
val drop = DataFrameDropColumns(
Seq(
UnresolvedAttribute("state"),
UnresolvedAttribute("national"),
UnresolvedAttribute("newAge")),
evalProject)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

// TODO this test should work when the bug https://issues.apache.org/jira/browse/SPARK-49782 fixed.
ignore("fields minus new function expression added by eval") {
val frame = sql(s"""
| source = $testTable| eval national = lower(country), newAge = age + 1
| | fields - state, national, newAge | sort - age | head 1
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)
val expectedResults: Array[Row] = Array(Row("Jake", 70, "USA", 2023, 4))
assert(results.sameElements(expectedResults))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val lowerFunction =
UnresolvedFunction("lower", Seq(UnresolvedAttribute("country")), isDistinct = false)
val addFunction =
UnresolvedFunction("+", Seq(UnresolvedAttribute("age"), Literal(1)), isDistinct = false)
val evalProject = Project(
Seq(
UnresolvedStar(None),
Alias(lowerFunction, "national")(),
Alias(addFunction, "newAge")()),
table)
val drop = DataFrameDropColumns(
Seq(
UnresolvedAttribute("state"),
UnresolvedAttribute("national"),
UnresolvedAttribute("newAge")),
evalProject)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, drop)
val limitPlan: LogicalPlan = Limit(Literal(1), sortedPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
2 changes: 2 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ See the next samples of PPL queries :
**Fields**
- `source = table`
- `source = table | fields a,b,c`
- `source = table | fields + a,b,c`
- `source = table | fields - b,c`

**Nested-Fields**
- `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,20 +265,24 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {

@Override
public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
context.withProjectedFields(node.getProjectList());
if (!node.isExcluded()) {
context.withProjectedFields(node.getProjectList());
}
LogicalPlan child = node.getChild().get(0).accept(this, context);
visitExpressionList(node.getProjectList(), context);

// Create a projection list from the existing expressions
Seq<?> projectList = seq(context.getNamedParseExpressions());
if (!projectList.isEmpty()) {
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
// build the plan with the projection step
child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
}
if (node.hasArgument()) {
Argument argument = node.getArgExprList().get(0);
//todo exclude the argument from the projected arguments list
if (node.isExcluded()) {
Seq<Expression> dropList = context.retainAllNamedParseExpressions(p -> p);
// build the DataFrameDropColumns plan with drop list
child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns(dropList, p));
} else {
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
// build the plan with the projection step
child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
}
}
return child;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,38 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

comparePlans(expectedPlan, logPlan, false)
}

test("test fields + field list") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(pplParser, "source=t | sort - A | fields + A, B | head 5", false),
context)

val table = UnresolvedRelation(Seq("t"))
val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
val sorted = Sort(sortOrder, true, table)
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val projection = Project(projectList, sorted)

val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projection))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
comparePlans(expectedPlan, logPlan, false)
}

test("test fields - field list") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(pplParser, "source=t | sort - A | fields - A, B | head 5", false),
context)

val table = UnresolvedRelation(Seq("t"))
val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
val sorted = Sort(sortOrder, true, table)
val dropList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val dropAB = DataFrameDropColumns(dropList, sorted)

val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), dropAB))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
comparePlans(expectedPlan, logPlan, false)
}
}

0 comments on commit f78358d

Please sign in to comment.