diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 4ea564111..27904c59d 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -246,13 +246,15 @@ source = table | where ispresent(a) | - `source=accounts | rare gender` - `source=accounts | rare age by gender` +- `source=accounts | rare age by gender sample(50 percent)` #### **Top** [See additional command details](ppl-top-command.md) - `source=accounts | top gender` - `source=accounts | top 1 gender` -- `source=accounts | top 1 age by gender` +- `source=accounts | top 5 gender sample(50 percent)` +- `source=accounts | top 5 age by gender` #### **Parse** [See additional command details](ppl-parse-command.md) diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md index 5645382f8..9a38c5a15 100644 --- a/docs/ppl-lang/ppl-rare-command.md +++ b/docs/ppl-lang/ppl-rare-command.md @@ -6,10 +6,11 @@ Using ``rare`` command to find the least common tuple of values of all fields in **Note**: A maximum of 10 results is returned for each distinct tuple of values of the group-by fields. **Syntax** -`rare [by-clause]` +`rare [by-clause] [sample(? percent)]` * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. +* sample: optional. 
allows reducing the amount of rows being scanned using a table sampling strategy, favouring velocity over precision ### Example 1: Find the least common values in a field @@ -44,3 +45,10 @@ PPL query: | M | 33 | | M | 36 | +----------+-------+ + +### Example 3: Find the least common values using 50 % sampling strategy + +PPL query: + + os> source=accounts | rare age sample(50 percent); + fetched rows / total rows = 2/4 diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md index 4ba56f692..f92acea39 100644 --- a/docs/ppl-lang/ppl-top-command.md +++ b/docs/ppl-lang/ppl-top-command.md @@ -5,11 +5,12 @@ Using ``top`` command to find the most common tuple of values of all fields in t ### Syntax -`top [N] [by-clause]` +`top [N] [by-clause] [sample(? percent)]` * N: number of results to return. **Default**: 10 * field-list: mandatory. comma-delimited list of field names. * by-clause: optional. one or more fields to group the results by. +* sample: optional. allows reducing the amount of rows being scanned using a table sampling strategy, favouring velocity over precision ### Example 1: Find the most common values in a field @@ -56,3 +57,12 @@ PPL query: | M | 32 | +----------+-------+ +### Example 3: Find the most common values organized by gender using a sample strategy + +The example finds the most common age of all the accounts grouped by gender, sampling only 50 % of the rows. 
+ +PPL query: + + os> source=accounts | top 1 age by gender sample(50 percent); + fetched rows / total rows = 1/2 + diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index f10b6e2f5..e65a1384c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -84,6 +84,52 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl rare address field query test sample 75 %") { + val frame = sql(s""" + | source = $testTable| rare address sample(75 percent) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver") + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + Sample(0, 0.75, withReplacement = false, 0, table)) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan 
= Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl rare address by age field query test") { val frame = sql(s""" | source = $testTable| rare address by age @@ -111,11 +157,58 @@ class FlintSparkPPLTopAndRareITSuite "count_address")() val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregatePlan = + Aggregate(Seq(addressField, ageAlias), aggregateExpressions, table) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare address by age field query test sample 75 %") { + val frame = sql(s""" + | source = $testTable| rare address by age sample(75 percent) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 4) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val aggregatePlan = Aggregate( Seq(addressField, ageAlias), aggregateExpressions, - 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + Sample(0, 0.75, withReplacement = false, 0, table)) val sortedPlan: LogicalPlan = Sort( @@ -226,6 +319,46 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top 3 countries query test sample 75 %") { + val frame = sql(s""" + | source = $newTestTable| top 3 country sample(75 percent) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")) + val aggregatePlan = + Aggregate( + Seq(countryField), + aggregateExpressions, + Sample(0, 0.75, withReplacement = false, 0, table)) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 2 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation @@ -254,11 +387,53 @@ class FlintSparkPPLTopAndRareITSuite UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), "count_country")() val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val table = UnresolvedRelation(Seq("spark_catalog", "default", 
"new_flint_ppl_test")) + val aggregatePlan = + Aggregate(Seq(countryField, occupationFieldAlias), aggregateExpressions, table) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } + + test("create ppl top 2 countries by occupation field query test sample 85 %") { + val frame = sql(s""" + | source = $newTestTable| top 3 country by occupation sample(85 percent) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")) val aggregatePlan = Aggregate( Seq(countryField, occupationFieldAlias), aggregateExpressions, - UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + Sample(0, 0.85, withReplacement = false, 0, table)) val sortedPlan: LogicalPlan = Sort( diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2c3344b3c..993654752 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ 
b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -24,6 +24,7 @@ SORT: 'SORT'; EVAL: 'EVAL'; HEAD: 'HEAD'; TOP: 'TOP'; +SAMPLE: 'SAMPLE'; RARE: 'RARE'; PARSE: 'PARSE'; METHOD: 'METHOD'; @@ -79,6 +80,7 @@ DESC: 'DESC'; DATASOURCES: 'DATASOURCES'; USING: 'USING'; WITH: 'WITH'; +PERCENT: 'PERCENT'; // FIELD KEYWORDS AUTO: 'AUTO'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 1cfd172f7..534569cba 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -179,12 +179,16 @@ headCommand : HEAD (number = integerLiteral)? (FROM from = integerLiteral)? ; +sampleClause + : SAMPLE '(' (percentage = integerLiteral PERCENT ) ')' + ; + topCommand - : TOP (number = integerLiteral)? fieldList (byClause)? + : TOP (number = integerLiteral)? fieldList (byClause)? (sampleClause)? ; rareCommand - : RARE fieldList (byClause)? + : RARE fieldList (byClause)? (sampleClause)? ; grokCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java index 4f4824fb8..572ece7a5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java @@ -6,6 +6,7 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; @@ -16,6 +17,7 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; /** Logical plan node of Aggregation, the interface for building aggregation actions in queries. 
*/ @Getter @@ -29,7 +31,8 @@ public class Aggregation extends UnresolvedPlan { private UnresolvedExpression span; private List argExprList; private UnresolvedPlan child; - + private Optional sample = Optional.empty(); + /** Aggregation Constructor without span and argument. */ public Aggregation( List aggExprList, @@ -71,4 +74,14 @@ public List getChild() { public T accept(AbstractNodeVisitor nodeVisitor, C context) { return nodeVisitor.visitAggregation(this, context); } + + @Getter + @Setter + @ToString + @EqualsAndHashCode(callSuper = false) + @AllArgsConstructor + public static class TablesampleContext { + public int percentage; + } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 53dc17576..8781e84df 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -11,6 +11,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Sample; import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.Metadata; @@ -21,8 +22,10 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.Stack; import java.util.function.BiFunction; import java.util.function.Function; @@ -35,6 +38,7 @@ /** * The context used for Catalyst logical plan. 
+ * A query which translates into multiple plans (sub-query / join-subQuery / scala-subQuery) will have multiple contexts */ public class CatalystPlanContext { /** @@ -57,6 +61,10 @@ public class CatalystPlanContext { * The current traversal context the visitor is going threw */ private Stack planTraversalContext = new Stack<>(); + /** + * indicate this plan has to sample the relation rather than take the entire data + */ + public Optional samplePercentage = Optional.empty(); /** * NamedExpression contextual parameters @@ -106,6 +114,18 @@ public LogicalPlan define(Expression symbol) { return getPlan(); } + /** + * indicate this plan context is using table sampling + */ + public CatalystPlanContext withSamplePercentage(int percentage) { + this.samplePercentage = Optional.of(percentage); + return this; + } + + public Optional getSamplePercentage() { + return this.samplePercentage; + } + /** * append relation to relations list * @@ -117,6 +137,17 @@ public LogicalPlan withRelation(UnresolvedRelation relation) { return with(relation); } + /** + * append sample-relation to relations list + * + * @param sampleRelation + * @return + */ + public LogicalPlan withSampleRelation(Sample sampleRelation) { + this.relations.add(sampleRelation.child()); + return with(sampleRelation); + } + public void withSubqueryAlias(SubqueryAlias subqueryAlias) { this.subqueryAlias.add(subqueryAlias); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index d2ee46ae6..0028ba98e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -14,13 +14,6 @@ import org.apache.spark.sql.catalyst.expressions.Explode; import org.apache.spark.sql.catalyst.expressions.Expression; import 
org.apache.spark.sql.catalyst.expressions.GeneratorOuter; -import org.apache.spark.sql.catalyst.expressions.In$; -import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.InSubquery$; -import org.apache.spark.sql.catalyst.expressions.LessThan; -import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; -import org.apache.spark.sql.catalyst.expressions.ListQuery$; -import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; @@ -31,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Project$; +import org.apache.spark.sql.catalyst.plans.logical.Sample; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; @@ -38,6 +32,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; @@ -123,7 +118,16 @@ public CatalystQueryPlanVisitor() { public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } - + + /** + * visit first child of the given node + * + * @return + */ + private LogicalPlan visitChild(Node node, CatalystPlanContext context) { + return node.getChild().get(0).accept(this, context); + } + /** * Handle Query Statement. 
*/ @@ -140,26 +144,36 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + //relation has no visit child method call since its the down most element in the AST tree if (node instanceof DescribeRelation) { - TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); - return context.with( - new DescribeTableCommand( - identifier, - scala.collection.immutable.Map$.MODULE$.empty(), - true, - DescribeRelation$.MODULE$.getOutputAttrs())); + return visitDescribeRelation((DescribeRelation)node, context); } //regular sql algebraic relations - node.getQualifiedNames().forEach(q -> - // Resolving the qualifiedName which is composed of a datasource.schema.table - context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false)) - ); + node.getQualifiedNames().forEach(q -> { + // Resolving the qualifiedName which is composed of a datasource.schema.table + UnresolvedRelation relation = new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false); + if(context.getSamplePercentage().isPresent()) { + context.withSampleRelation(new Sample(0, (double)context.getSamplePercentage().get() / 100, false, 0, relation)); + } else { + context.withRelation(relation); + } + }); return context.getPlan(); } + private static LogicalPlan visitDescribeRelation(DescribeRelation node, CatalystPlanContext context) { + TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); + return context.with( + new DescribeTableCommand( + identifier, + scala.collection.immutable.Map$.MODULE$.empty(), + true, + DescribeRelation$.MODULE$.getOutputAttrs())); + } + @Override public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); return context.apply(p -> { Expression 
conditionExpression = visitExpression(node.getCondition(), context); Optional innerConditionExpression = context.popNamedParseExpressions(); @@ -173,8 +187,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { */ @Override public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitChild(node, context); return context.apply( searchSide -> { LogicalPlan lookupTable = node.getLookupRelation().accept(this, context); Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context); @@ -230,8 +243,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { @Override public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitChild(node, context); node.getSortByField() .ifPresent(sortField -> { Expression sortFieldExpression = visitExpression(sortField, context); @@ -254,7 +266,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); context.reduce((left, right) -> { visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); Seq fields = context.retainAllNamedParseExpressions(e -> e); @@ -272,7 +284,7 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex @Override public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); return context.apply(left -> { LogicalPlan right = node.getRight().accept(this, context); Optional joinCondition = node.getJoinCondition() @@ -285,7 +297,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) { @Override public LogicalPlan 
visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); return context.apply(p -> { var alias = org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p); context.withSubqueryAlias(alias); @@ -296,7 +308,11 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co @Override public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + //add sample context (if exists) to the plan context + if(node.getSample().isPresent()) { + context.withSamplePercentage(node.getSample().get().getPercentage()); + } + visitChild(node, context); List aggsExpList = visitExpressionList(node.getAggExprList(), context); List groupExpList = visitExpressionList(node.getGroupExprList(), context); if (!groupExpList.isEmpty()) { @@ -342,7 +358,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { @Override public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); List partitionExpList = visitExpressionList(node.getPartExprList(), context); @@ -384,7 +400,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { } else { context.withProjectedFields(node.getProjectList()); } - LogicalPlan child = node.getChild().get(0).accept(this, context); + LogicalPlan child = visitChild(node, context); visitExpressionList(node.getProjectList(), context); // Create a projection list from the existing expressions @@ -405,7 +421,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { @Override public LogicalPlan visitSort(Sort node, CatalystPlanContext 
context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); visitFieldList(node.getSortList(), context); Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); @@ -413,20 +429,20 @@ public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { @Override public LogicalPlan visitHead(Head node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( node.getSize(), DataTypes.IntegerType), p)); } @Override public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { - fieldSummary.getChild().get(0).accept(this, context); + visitChild(fieldSummary, context); return FieldSummaryTransformer.translate(fieldSummary, context); } @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { - fillNull.getChild().get(0).accept(this, context); + visitChild(fillNull, context); List aliases = new ArrayList<>(); for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) { Field field = nullableFieldFill.getNullableFieldReference(); @@ -457,7 +473,7 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) @Override public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { - flatten.getChild().get(0).accept(this, context); + visitChild(flatten, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); @@ -507,7 +523,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan @Override public 
LogicalPlan visitParse(Parse node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitChild(node, context); Expression sourceField = visitExpression(node.getSourceField(), context); ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); @@ -517,7 +533,7 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { @Override public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); @@ -534,7 +550,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitChild(node,context); List aliases = new ArrayList<>(); List letExpressions = node.getExpressionList(); for (Let let : letExpressions) { @@ -548,8 +564,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { List expressionList = visitExpressionList(aliases, context); Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - return child; + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); } @Override @@ -574,7 +589,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext @Override public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitChild(node, context); List options = node.getOptions(); 
Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index f6581016f..2d106df4b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.Aggregation.TablesampleContext; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; @@ -465,6 +466,10 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); + if(ctx.sampleClause()!=null) { + int percentage = Integer.parseInt(ctx.sampleClause().percentage.getText()); + aggregation.setSample(Optional.of(new TablesampleContext(percentage))); + } return aggregation; } @@ -510,6 +515,10 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); + if(ctx.sampleClause()!=null) { + int percentage = Integer.parseInt(ctx.sampleClause().percentage.getText()); + aggregation.setSample(Optional.of(new TablesampleContext(percentage))); + } return aggregation; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java index 1dc7b9878..67ebe8426 100644 --- 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java @@ -13,6 +13,7 @@ import java.util.List; import java.util.Optional; +import java.util.Set; public interface RelationUtils { /** diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 792a2dee6..b18c205b1 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -98,6 +98,47 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("test simple rare command with a by field test percentage") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logicalPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | rare address by age sample(50 percent)"), + context) + // Retrieve the logical plan + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + Sample(0, 0.5, withReplacement = false, 0, UnresolvedRelation(Seq("accounts")))) + + val sortedPlan: 
LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("test simple top command with a single field") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -131,6 +172,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test simple top command with a single field sample(50 percent) ") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | top address sample(50 percent)"), + context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test simple top 1 command by age field") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -168,6 +247,45 @@ class 
PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test simple top 5 command by age field sample(25 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | top 5 address by age sample(25 percent)"), + context) + + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + Sample(0, 0.25, withReplacement = false, 0, UnresolvedRelation(Seq("accounts")))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(5), LocalLimit(Literal(5), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, false) + } + test("create ppl top 3 countries by occupation field query test") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -207,4 +325,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("create ppl top 3 countries by occupation field query test with sample(25 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts | top 3 country by occupation 
sample(25 percent)"), + context) + + val tableRelation = UnresolvedRelation(Seq("accounts")) + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + Sample(0, 0.25, withReplacement = false, 0, tableRelation)) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + }