Skip to content

Commit

Permalink
Adding support for the PPL `rare` and `top` commands
Browse files Browse the repository at this point in the history
top [N] <field-list> [by-clause]

N: number of results to return. Default: 10
field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. comma-delimited list of field names.
by-clause: optional. one or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - opensearch-project#461
 - opensearch-project#536
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Aug 15, 2024
1 parent 0128e2b commit 7fc2665
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class FlintSparkPPLTopAndRareITSuite

test("create ppl rare address field query test") {
val frame = sql(s"""
| source = $testTable| rare address"
| source = $testTable| rare address
| """.stripMargin)

// Retrieve the results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
node.getChild().get(0).accept(this, context);
List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
List<Expression> sortExpList = visitExpressionList(node.getSortExprList(), context);
if (!groupExpList.isEmpty()) {
//add group by fields to context
context.getGroupingParseExpressions().addAll(groupExpList);
Expand All @@ -199,7 +198,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
List<SortDirection> sortDirections = new ArrayList<>();
sortDirections.add(node instanceof RareAggregation ? Ascending$.MODULE$ : node instanceof TopAggregation ? Descending$.MODULE$ : Ascending$.MODULE$);

if (!sortExpList.isEmpty()) {
if (!node.getSortExprList().isEmpty()) {
visitExpressionList(node.getSortExprList(), context);
Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp ->
new SortOrder(exp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,20 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
@Override
public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
ctx.fieldList().fieldExpression().forEach(field -> {
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
String name = field.qualifiedName().getText();
Alias alias = new Alias(name, aggExpression);
Alias alias = new Alias("count("+name+")", aggExpression);
aggListBuilder.add(alias);
// group by the `field-list` as the mandatory groupBy fields
groupListBuilder.add(internalVisitExpression(field));
});
List<UnresolvedExpression> groupList =

// group by the `by-clause` as the optional groupBy fields
groupListBuilder.addAll(
Optional.ofNullable(ctx.byClause())
.map(OpenSearchPPLParser.ByClauseContext::fieldList)
.map(
Expand All @@ -297,31 +303,38 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx)
getTextInQuery(groupCtx),
internalVisitExpression(groupCtx)))
.collect(Collectors.toList()))
.orElse(emptyList());



.orElse(emptyList())
);
//build the sort fields
ctx.fieldList().fieldExpression().forEach(field -> {
sortListBuilder.add(internalVisitExpression(field));
});
TopAggregation aggregation =
new TopAggregation(
aggListBuilder.build(),
emptyList(),
groupList);
sortListBuilder.build(),
groupListBuilder.build());
return aggregation;
}

/** Rare command. */
@Override
public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
ctx.fieldList().fieldExpression().forEach(field -> {
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
String name = field.qualifiedName().getText();
Alias alias = new Alias(name, aggExpression);
Alias alias = new Alias("count("+name+")", aggExpression);
aggListBuilder.add(alias);
// group by the `field-list` as the mandatory groupBy fields
groupListBuilder.add(internalVisitExpression(field));
});
List<UnresolvedExpression> groupList =

// group by the `by-clause` as the optional groupBy fields
groupListBuilder.addAll(
Optional.ofNullable(ctx.byClause())
.map(OpenSearchPPLParser.ByClauseContext::fieldList)
.map(
Expand All @@ -334,7 +347,8 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
getTextInQuery(groupCtx),
internalVisitExpression(groupCtx)))
.collect(Collectors.toList()))
.orElse(emptyList());
.orElse(emptyList())
);
//build the sort fields
ctx.fieldList().fieldExpression().forEach(field -> {
sortListBuilder.add(internalVisitExpression(field));
Expand All @@ -343,9 +357,8 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
new RareAggregation(
aggListBuilder.build(),
sortListBuilder.build(),
groupList);
groupListBuilder.build());
return aggregation;

}

/** From clause. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite

comparePlans(expectedPlan, logPlan, false)
}

test("test count price") {
// Verifies that an ungrouped `stats count(price)` PPL query is translated into a
// Catalyst Aggregate with no grouping expressions, wrapped in a `*` projection.
// Step 1: build the PPL logical plan and translate it to a Catalyst logical plan.
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table | stats count(price) ", false), context)
// Equivalent SQL: SELECT COUNT(price) AS `count(price)` FROM table
val star = Seq(UnresolvedStar(None))

val priceField = UnresolvedAttribute("price")
val tableRelation = UnresolvedRelation(Seq("table"))
// The aggregate output column is aliased with the literal text "count(price)".
val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")())
// Empty grouping list: the count is computed over the whole relation.
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

// Step 2: the translated plan must match the hand-built expected plan.
comparePlans(expectedPlan, logPlan, false)
}

// NOTE(review): the test name says "by country", but the query below groups by
// `product` — consider renaming the test to match.
test("test count price by country") {
// Verifies that `stats count(price) by product` is translated into a Catalyst
// Aggregate grouped by `product`, wrapped in a `*` projection.
// Step 1: build the PPL logical plan and translate it to a Catalyst logical plan.
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table | stats count(price) by product ", false), context)
// Equivalent SQL: SELECT COUNT(price) AS `count(price)`, product FROM table GROUP BY product
val star = Seq(UnresolvedStar(None))
val productField = UnresolvedAttribute("product")
val priceField = UnresolvedAttribute("price")
val tableRelation = UnresolvedRelation(Seq("table"))

// Grouping expression: the `product` attribute aliased as "product".
val groupByAttributes = Seq(Alias(productField, "product")())
// The aggregate output column is aliased with the literal text "count(price)".
val aggregateExpressions =
Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")()
// The grouping field is also emitted in the aggregate's output list, after the count.
val productAlias = Alias(productField, "product")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

// Step 2: the translated plan must match the hand-built expected plan.
comparePlans(expectedPlan, logPlan, false)
}

test("test average price with Alias") {
// if successful build ppl logical plan and translate to catalyst logical plan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
Expand All @@ -29,149 +29,25 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=accounts | rare gender", false), context)
val genderField = UnresolvedAttribute("gender")
val tableRelation = UnresolvedRelation(Seq("accounts"))

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("accounts")))
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with escaped table name") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=`table`", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with schema.table and no explicit fields (defaults to all fields)") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=schema.table", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
comparePlans(expectedPlan, logPlan, false)

}

test("test simple search with schema.table and one field projected") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=schema.table | fields A", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with one field projected") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=table | fields A", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with two fields projected") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t | fields A, B", false), context)

val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val expectedPlan = Project(projectList, table)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with one table with two fields projected sorted by one field") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=t | sort A | fields A, B", false), context)

val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
// Sort by A ascending
val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Ascending))
val sorted = Sort(sortOrder, true, table)
val expectedPlan = Project(projectList, sorted)

comparePlans(expectedPlan, logPlan, false)
}

test(
"test simple search with only one table with two fields with head (limit ) command projected") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=t | fields A, B | head 5", false), context)

val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val planWithLimit =
GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table)))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
comparePlans(expectedPlan, logPlan, false)
}

test(
"test simple search with only one table with two fields with head (limit ) command projected sorted by one descending field") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(pplParser, "source=t | sort - A | fields A, B | head 5", false),
context)

val table = UnresolvedRelation(Seq("t"))
val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
val sorted = Sort(sortOrder, true, table)
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val projectAB = Project(projectList, sorted)

val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
comparePlans(expectedPlan, logPlan, false)
}

test(
"Search multiple tables - translated into union call - fields expected to exist in both tables ") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(pplParser, "search source = table1, table2 | fields A, B", false),
context)

val table1 = UnresolvedRelation(Seq("table1"))
val table2 = UnresolvedRelation(Seq("table2"))

val allFields1 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val allFields2 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))

val projectedTable1 = Project(allFields1, table1)
val projectedTable2 = Project(allFields2, table2)

val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

comparePlans(expectedPlan, logPlan, false)
}

test("Search multiple tables - translated into union call with fields") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table1, table2 ", false), context)

val table1 = UnresolvedRelation(Seq("table1"))
val table2 = UnresolvedRelation(Seq("table2"))

val allFields1 = UnresolvedStar(None)
val allFields2 = UnresolvedStar(None)

val projectedTable1 = Project(Seq(allFields1), table1)
val projectedTable2 = Project(Seq(allFields2), table2)
val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false), "count(gender)")(),
genderField
)

val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)
val aggregatePlan =
Aggregate(Seq(genderField), aggregateExpressions, tableRelation)

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("gender"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logPlan, false)
}
}

0 comments on commit 7fc2665

Please sign in to comment.