Skip to content

Commit

Permalink
additional refactory update the limit / sort visitor functions
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 26, 2023
1 parent 07529ea commit 90ba63c
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class FlintSparkPPLAggregationWithSpanITSuite
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
val projectPlan = Project(star, aggregatePlan)
val expectedPlan = Limit(Literal(2), projectPlan)
val limitPlan = Limit(Literal(2), aggregatePlan)
val expectedPlan = Project(star, limitPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Expand Down Expand Up @@ -250,8 +250,8 @@ class FlintSparkPPLAggregationWithSpanITSuite
"age_span")()
val aggregatePlan =
Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table)
val projectPlan = Project(star, aggregatePlan)
val expectedPlan = Limit(Literal(3), projectPlan)
val limitPlan = Limit(Literal(3), aggregatePlan)
val expectedPlan = Project(star, limitPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Expand Down Expand Up @@ -283,13 +283,13 @@ class FlintSparkPPLAggregationWithSpanITSuite
"age_span")()
val aggregatePlan =
Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table)
val projectPlan = Project(star, aggregatePlan)
val expectedPlan = Limit(Literal(2), projectPlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)),
global = true,
expectedPlan)
aggregatePlan)
val limitPlan = Limit(Literal(2), sortedPlan)
val expectedPlan = Project(star, limitPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ class FlintSparkPPLAggregationsITSuite

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan)
val expectedPlan = Limit(Literal(1), projectPlan)
val projectPlan = Limit(Literal(1), aggregatePlan)
val expectedPlan = Project(Seq(UnresolvedStar(None)), projectPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Expand Down Expand Up @@ -326,11 +326,11 @@ class FlintSparkPPLAggregationsITSuite

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan)
Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age count group by country query test ") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,11 @@ class FlintSparkPPLFiltersITSuite
GreaterThan(UnresolvedAttribute("age"), Literal(10)),
Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)

val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan)
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, filterPlan)
val expectedPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
Expand Down Expand Up @@ -199,8 +197,8 @@ class FlintSparkPPLFiltersITSuite
EqualTo(UnresolvedAttribute("country"), Literal("USA")))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan))
val expectedPlan = Limit(Literal(1), projectPlan)
val limitPlan = Limit(Literal(1), Project(projectList, filterPlan))
val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
Expand Down Expand Up @@ -278,11 +276,11 @@ class FlintSparkPPLFiltersITSuite
val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan)
Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, filterPlan)
val expectedPlan = Project(projectList, sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple name literal equal filter query with two fields result test") {
Expand Down Expand Up @@ -394,8 +392,8 @@ class FlintSparkPPLFiltersITSuite
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
val projectPlan = Project(star, aggregatePlan)
val expectedPlan = Limit(Literal(2), projectPlan)
val limitPlan = Limit(Literal(2), aggregatePlan)
val expectedPlan = Project(star, limitPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ class FlintSparkPPLITSuite
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan = Limit(
val limitPlan: LogicalPlan = Limit(
Literal(2),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))))
UnresolvedRelation(Seq("default", "flint_ppl_test")))
val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
Expand All @@ -124,14 +126,17 @@ class FlintSparkPPLITSuite

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, UnresolvedRelation(Seq("default", "flint_ppl_test")))

// Define the expected logical plan
val expectedPlan: LogicalPlan = Limit(
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), Limit(
Literal(2),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))))
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan)
sortedPlan ))

// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query two with fields result test") {
Expand Down Expand Up @@ -172,15 +177,17 @@ class FlintSparkPPLITSuite

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("default", "flint_ppl_test")))

// Define the expected logical plan
val expectedPlan: LogicalPlan = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("default", "flint_ppl_test")))
sortedPlan)

val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan)
// Compare the two plans
assert(sortedPlan === logicalPlan)
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query two with fields and head (limit) test") {
Expand All @@ -198,9 +205,10 @@ class FlintSparkPPLITSuite
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("default", "flint_ppl_test")))
// Define the expected logical plan
val expectedPlan: LogicalPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), project))
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), limitPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query two with fields and head (limit) with sorting test") {
Expand All @@ -218,11 +226,13 @@ class FlintSparkPPLITSuite
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("default", "flint_ppl_test")))
// Define the expected logical plan
val expectedPlan: LogicalPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), project))
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan)
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan);
// Compare the two plans
assert(sortedPlan === logicalPlan)
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,12 @@ class FlintSparkPPLTimeWindowITSuite
"sum(productsAmount)")()
val aggregatePlan =
Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
global = true,
expectedPlan)
global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl query count sales by days window and productId with sorting test") {
Expand Down Expand Up @@ -309,13 +308,13 @@ class FlintSparkPPLTimeWindowITSuite
Seq(productsId, windowExpression),
Seq(aggregateExpressions, productsId, windowExpression),
table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
global = true,
expectedPlan)
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
test("create ppl query count sales by weeks window and productId with sorting test") {
val frame = sql(s"""
Expand Down Expand Up @@ -367,13 +366,13 @@ class FlintSparkPPLTimeWindowITSuite
"sum(productsAmount)")()
val aggregatePlan =
Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table)
val expectedPlan = Project(star, aggregatePlan)
val sortedPlan: LogicalPlan = Sort(
Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)),
global = true,
expectedPlan)
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(sortedPlan) === compareByString(logicalPlan))
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

ignore("create ppl simple count age by span of interval of 10 years query order by age test ") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ public class CatalystPlanContext {
**/
private Stack<LogicalPlan> planBranches = new Stack<>();

/**
* limit stands for the translation of the `head` command in PPL which transforms into a limit logical step.
* default limit -MAX_INT_VAL meaning no limit was set yet
*/
private int limit = Integer.MIN_VALUE;

/**
* NamedExpression contextual parameters
**/
Expand All @@ -43,12 +37,7 @@ public class CatalystPlanContext {
* Grouping NamedExpression contextual parameters
**/
private final Stack<org.apache.spark.sql.catalyst.expressions.Expression> groupingParseExpressions = new Stack<>();

/**
* SortOrder sort by parameters
**/
private Seq<SortOrder> sortOrders = seq(emptyList());


public LogicalPlan getPlan() {
if (this.planBranches.size() == 1) {
return planBranches.peek();
Expand All @@ -74,26 +63,12 @@ public void with(LogicalPlan plan) {
this.planBranches.push(plan);
}

public void limit(int limit) {
this.limit = limit;
}

public int getLimit() {
return limit;
}

public Seq<SortOrder> getSortOrders() {
return sortOrders;
}

public void plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
public LogicalPlan plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
this.planBranches.replaceAll(transformFunction::apply);
return getPlan();
}
public void sort(Seq<SortOrder> sortOrders) {
this.sortOrders = sortOrders;
}

/**

/**
* retain all expressions and clear expression stack
* @return
*/
Expand Down
Loading

0 comments on commit 90ba63c

Please sign in to comment.