Skip to content

Commit

Permalink
add additional support for rare & top commands options including …
Browse files Browse the repository at this point in the history
…top N ...

Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Aug 16, 2024
1 parent ccf087a commit 036590a
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,54 @@ class FlintSparkPPLTopAndRareITSuite
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl top 3 countries by occupation field query test") {
  // Populate a separate table through the suite's createOccupationTable helper.
  val newTestTable = "spark_catalog.default.new_flint_ppl_test"
  createOccupationTable(newTestTable)

  val frame = sql(s"""
       | source = $newTestTable| top 3 country by occupation
       | """.stripMargin)

  // Retrieve the results; `top 3` must cap the output at exactly three rows.
  val results: Array[Row] = frame.collect()
  assert(results.length == 3)

  val expectedRows =
    Set(Row(1, "Canada", "Doctor"), Row(1, "Canada", "Scientist"), Row(1, "Canada", "Unemployed"))
  // The length was asserted above, so all collected rows take part in the comparison.
  val actualRows = results.toSet

  // Compare as sets so the relative order of the rows does not matter.
  assert(
    actualRows == expectedRows,
    s"The top 3 results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

  // Retrieve the logical plan
  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  // Expected plan, built bottom-up: Aggregate -> Sort -> Limit -> Project.
  val countryField = UnresolvedAttribute("country")
  val occupationField = UnresolvedAttribute("occupation")
  val occupationFieldAlias = Alias(occupationField, "occupation")()

  val countExpr =
    Alias(UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), "count(country)")()
  val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
  val aggregatePlan =
    Aggregate(
      Seq(countryField, occupationFieldAlias),
      aggregateExpressions,
      UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))

  val sortedPlan: LogicalPlan =
    Sort(
      Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
      global = true,
      aggregatePlan)

  // The `3` of `top 3` is expected as a GlobalLimit over a LocalLimit on the sorted aggregation.
  val planWithLimit =
    GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
  val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
  comparePlans(expectedPlan, logicalPlan, false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * Logical plan node of Top (Aggregation) command, the interface for building aggregation actions
 * in queries.
 *
 * <p>Besides the expression lists passed on to {@link Aggregation}, this node keeps an optional
 * literal that callers read through {@link #getResults()}.
 */
public class TopAggregation extends Aggregation {
    /** Optional results literal; never {@code null} (an absent value is {@link Optional#empty()}). */
    private final Optional<Literal> results;

    /**
     * Aggregation Constructor without span and argument.
     *
     * @param results optional results literal; a {@code null} reference is treated as
     *     {@link Optional#empty()} so that {@link #getResults()} can always be queried safely
     * @param aggExprList aggregation expressions, forwarded to {@link Aggregation}
     * @param sortExprList sort expressions, forwarded to {@link Aggregation}
     * @param groupExprList group-by expressions, forwarded to {@link Aggregation}
     */
    public TopAggregation(
            Optional<Literal> results,
            List<UnresolvedExpression> aggExprList,
            List<UnresolvedExpression> sortExprList,
            List<UnresolvedExpression> groupExprList) {
        // The parent receives null and an empty list for its remaining two parameters
        // (presumably the span and the argument list this constructor omits — confirm in Aggregation).
        super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
        // Guard against a null Optional reference: callers invoke getResults().isPresent().
        this.results = results == null ? Optional.empty() : results;
    }

    /**
     * @return the optional results literal given at construction time; never {@code null}
     */
    public Optional<Literal> getResults() {
        return results;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek());
}
// build the aggregation logical step
// context.apply(p -> extractedAggregation(context)); TODO remove
LogicalPlan logicalPlan = extractedAggregation(context);

// set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc)
Expand All @@ -207,6 +206,11 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
seq(new ArrayList<Expression>())));
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
}
return logicalPlan;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,10 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx)
ctx.fieldList().fieldExpression().forEach(field -> {
sortListBuilder.add(internalVisitExpression(field));
});
UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null);
TopAggregation aggregation =
new TopAggregation(
Optional.ofNullable((Literal) unresolvedPlan),
aggListBuilder.build(),
sortListBuilder.build(),
groupListBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ package org.opensearch.flint.spark.ppl
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.execution.command.DescribeTableCommand

class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
Expand Down Expand Up @@ -114,4 +113,37 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple top 1 command by age field") {
  // Parse the PPL query and translate it into a Catalyst logical plan.
  val planContext = new CatalystPlanContext
  val translatedPlan =
    planTransformer.visit(
      plan(pplParser, "source=accounts | top 1 address by age", false),
      planContext)

  // Expected plan, assembled bottom-up: Aggregate -> Sort -> Limit -> Project.
  val address = UnresolvedAttribute("address")
  val ageGrouping = Alias(UnresolvedAttribute("age"), "age")()
  val addressCount =
    Alias(UnresolvedFunction(Seq("COUNT"), Seq(address), isDistinct = false), "count(address)")()

  // The same alias instance is used for grouping and for output, as in the translated plan check.
  val aggregated =
    Aggregate(
      Seq(address, ageGrouping),
      Seq(addressCount, address, ageGrouping),
      UnresolvedRelation(Seq("accounts")))

  val ordered: LogicalPlan =
    Sort(Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), global = true, aggregated)

  // `top 1` is expected as a GlobalLimit over a LocalLimit, both with literal 1.
  val limited = GlobalLimit(Literal(1), LocalLimit(Literal(1), ordered))

  comparePlans(Project(Seq(UnresolvedStar(None)), limited), translatedPlan, false)
}
}

0 comments on commit 036590a

Please sign in to comment.