From cea6e4423747f85c8fc89ac9d191052f22f22afc Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 9 Oct 2024 19:51:40 -0700 Subject: [PATCH 01/12] add support for FieldSummary - antlr syntax - ast expression builder - ast node builder - catalyst ast builder Signed-off-by: YANGDB --- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 8 ++ .../src/main/antlr4/OpenSearchPPLParser.g4 | 12 +++ .../sql/ast/AbstractNodeVisitor.java | 11 +++ .../sql/ast/expression/FieldList.java | 34 +++++++++ .../sql/ast/expression/NamedExpression.java | 30 ++++++++ .../opensearch/sql/ast/tree/FieldSummary.java | 76 +++++++++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 7 ++ .../opensearch/sql/ppl/parser/AstBuilder.java | 8 +- .../sql/ppl/parser/AstExpressionBuilder.java | 37 +++++++++ ...eldsummaryCommandTranslatorTestSuite.scala | 55 ++++++++++++++ 10 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index dd43007f4..af14d15c6 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -86,6 +86,14 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; + +// FIELDSUMMARY keywords +FIELDSUMMARY: 'FIELDSUMMARY'; +INCLUDEFIELDS: 'INCLUDEFIELDS'; +EXCLUDEFIELDS: 'EXCLUDEFIELDS'; +TOPVALUES: 'TOPVALUES'; +NULLS: 'NULLS'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git 
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index fb1c79bd2..6222e2a1d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -52,6 +52,7 @@ commands | lookupCommand | renameCommand | fillnullCommand + | fieldsummaryCommand ; searchCommand @@ -59,6 +60,17 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; + +fieldsummaryCommand + : FIELDSUMMARY (fieldsummaryParameter)* + ; + +fieldsummaryParameter + : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields + | EXCLUDEFIELDS EQUAL fieldList # fieldsummaryExcludeFields + | TOPVALUES EQUAL integerLiteral # fieldsummaryTopValues + | NULLS EQUAL booleanLiteral # fieldsummaryNulls + ; describeCommand : DESCRIBE tableSourceClause diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e42306965..c02ea05b2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -16,6 +16,8 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -204,6 +206,10 @@ public T visitField(Field node, C context) { return visitChildren(node, context); } + public T visitFieldList(FieldList node, C context) { + return visitChildren(node, context); + } + public T 
visitQualifiedName(QualifiedName node, C context) { return visitChildren(node, context); } @@ -294,7 +300,12 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + public T visitFillNull(FillNull fillNull, C context) { return visitChildren(fillNull, context); } + + public T visitFieldSummary(FieldSummary fieldSummary, C context) { + return visitChildren(fieldSummary, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java new file mode 100644 index 000000000..4f6ac5e14 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Expression node that includes a list of fields nodes. 
*/ +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@AllArgsConstructor +public class FieldList extends UnresolvedExpression { + private final List fieldList; + + @Override + public List getChild() { + return ImmutableList.copyOf(fieldList); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldList(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java new file mode 100644 index 000000000..4fee68a09 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java @@ -0,0 +1,30 @@ +package org.opensearch.sql.ast.expression; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +@Getter +@ToString +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class NamedExpression extends UnresolvedExpression { + private final int expressionId; + private final UnresolvedExpression expression; + + // private final DataType valueType; + @Override + public List getChild() { + return Arrays.asList(expression); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visit(this, context); + } +} \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java new file mode 100644 index 000000000..cf7c8dfcc --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NamedExpression; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.EXCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +public class FieldSummary extends UnresolvedPlan { + private List includeFields; + private List excludeFields; + private int topValues; + private boolean nulls; + private List collect; + private UnresolvedPlan child; + + public FieldSummary(List collect) { + this.collect = collect; + collect.stream().filter(e->e instanceof NamedExpression) + .forEach(exp -> { + switch (((NamedExpression) exp).getExpressionId()) { + case NULLS: + this.nulls = (boolean) ((Literal) exp.getChild().get(0)).getValue(); + break; + case TOPVALUES: + this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue(); + break; + case EXCLUDEFIELDS: + this.excludeFields = ((FieldList) exp.getChild().get(0)).getFieldList(); + break; + case INCLUDEFIELDS: + this.includeFields = ((FieldList) exp.getChild().get(0)).getFieldList(); + break; + } + }); + } + + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldSummary(this, context); + } + 
+ @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index e6ab083ee..fd2e8a69f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -61,6 +61,7 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; @@ -391,6 +392,12 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { node.getSize(), DataTypes.IntegerType), p)); } + @Override + public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { + fieldSummary.getChild().get(0).accept(this, context); + return super.visitFieldSummary(fieldSummary, context); + } + @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { fillNull.getChild().get(0).accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8673b1582..84089d9ac 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; import 
org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -411,7 +412,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) groupListBuilder.build()); return aggregation; } - + + @Override + public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryCommandContext ctx) { + return new FieldSummary(ctx.fieldsummaryParameter().stream().map(arg -> expressionBuilder.visit(arg)).collect(Collectors.toList())); + } + /** Rare command. */ @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f5e9269be..0040fbc9e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; @@ -28,6 +29,7 @@ import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NamedExpression; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -37,6 +39,7 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import 
org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -48,6 +51,10 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.EXCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; @@ -177,6 +184,36 @@ public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.FieldsummaryIncludeFieldsContext ctx) { + List includeFields = ctx.fieldList().fieldExpression().stream() + .map(this::visitFieldExpression) + .map(p->(Field)p) + .collect(Collectors.toList()); + return new NamedExpression(INCLUDEFIELDS,new FieldList(includeFields)); + } + + @Override + public UnresolvedExpression visitFieldsummaryExcludeFields(OpenSearchPPLParser.FieldsummaryExcludeFieldsContext ctx) { + List excludeFields = ctx.fieldList().fieldExpression().stream() + .map(this::visitFieldExpression) + .map(p->(Field)p) + .collect(Collectors.toList()); + return new NamedExpression(EXCLUDEFIELDS,new FieldList(excludeFields)); + } + + @Override + public UnresolvedExpression visitFieldsummaryTopValues(OpenSearchPPLParser.FieldsummaryTopValuesContext ctx) { + return new NamedExpression(TOPVALUES,visitIntegerLiteral(ctx.integerLiteral())); + } + + + @Override + public UnresolvedExpression 
visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { + return new NamedExpression(NULLS,visitBooleanLiteral(ctx.booleanLiteral())); + } + + /** * Aggregation function. */ diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..620fdcaab --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +class PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fieldsummary with `includefields=status_code,user_id,response_time`") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= status_code, user_id, response_time topvalues=5 nulls=true"), + context) 
+ + val table = UnresolvedRelation(Seq("t")) + + val renameProjectList: Seq[NamedExpression] = + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute("column_name"), Literal("null replacement value")), + isDistinct = false), + "column_name")()) + val renameProject = Project(renameProjectList, table) + + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) + + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} \ No newline at end of file From fd1375aeba79f47be0fd4bf67e85c166d53fb12d Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 10 Oct 2024 15:46:22 -0700 Subject: [PATCH 02/12] add support for FieldSummary - antlr syntax - ast expression builder - ast node builder - catalyst ast builder Signed-off-by: YANGDB --- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 - .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 - .../opensearch/sql/ast/tree/FieldSummary.java | 10 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 3 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 914 +++++++++--------- .../sql/ppl/parser/AstExpressionBuilder.java | 9 - .../ppl/utils/FieldSummaryTransformer.java | 180 ++++ ...ldSummaryCommandTranslatorTestSuite.scala} | 2 +- 8 files changed, 663 insertions(+), 457 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java rename ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/{PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala => PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala} (97%) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index af14d15c6..a2b84d960 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ 
b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -90,7 +90,6 @@ NUM: 'NUM'; // FIELDSUMMARY keywords FIELDSUMMARY: 'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; -EXCLUDEFIELDS: 'EXCLUDEFIELDS'; TOPVALUES: 'TOPVALUES'; NULLS: 'NULLS'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 6222e2a1d..c9a425c05 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -67,7 +67,6 @@ fieldsummaryCommand fieldsummaryParameter : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields - | EXCLUDEFIELDS EQUAL fieldList # fieldsummaryExcludeFields | TOPVALUES EQUAL integerLiteral # fieldsummaryTopValues | NULLS EQUAL booleanLiteral # fieldsummaryNulls ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index cf7c8dfcc..1b6029803 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -10,6 +10,7 @@ import lombok.Getter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Literal; @@ -19,7 +20,6 @@ import java.util.List; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.EXCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; @@ -29,7 +29,6 @@ @EqualsAndHashCode(callSuper = false) public class FieldSummary extends UnresolvedPlan { 
private List includeFields; - private List excludeFields; private int topValues; private boolean nulls; private List collect; @@ -46,9 +45,6 @@ public FieldSummary(List collect) { case TOPVALUES: this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue(); break; - case EXCLUDEFIELDS: - this.excludeFields = ((FieldList) exp.getChild().get(0)).getFieldList(); - break; case INCLUDEFIELDS: this.includeFields = ((FieldList) exp.getChild().get(0)).getFieldList(); break; @@ -58,8 +54,8 @@ public FieldSummary(List collect) { @Override - public List getChild() { - return ImmutableList.of(); + public List getChild() { + return child == null ? List.of() : List.of(child); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index fd2e8a69f..aea8256ef 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -82,6 +82,7 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -395,7 +396,7 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { @Override public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getChild().get(0).accept(this, context); - return super.visitFieldSummary(fieldSummary, context); + return FieldSummaryTransformer.translate(fieldSummary, context); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 84089d9ac..7dd757d25 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -71,493 +71,533 @@ import static org.opensearch.sql.ast.tree.FillNull.ContainNullableFieldFill.ofVariousValue; -/** Class of building the AST. Refines the visit path and build the AST nodes */ +/** + * Class of building the AST. Refines the visit path and build the AST nodes + */ public class AstBuilder extends OpenSearchPPLParserBaseVisitor { - private AstExpressionBuilder expressionBuilder; + private AstExpressionBuilder expressionBuilder; - /** - * PPL query to get original token text. This is necessary because token.getText() returns text - * without whitespaces or other characters discarded by lexer. - */ - private String query; + /** + * PPL query to get original token text. This is necessary because token.getText() returns text + * without whitespaces or other characters discarded by lexer. 
+ */ + private String query; - public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { - this.expressionBuilder = expressionBuilder; - this.query = query; - } + public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { + this.expressionBuilder = expressionBuilder; + this.query = query; + } - @Override - public UnresolvedPlan visitQueryStatement(OpenSearchPPLParser.QueryStatementContext ctx) { - UnresolvedPlan pplCommand = visit(ctx.pplCommands()); - return ctx.commands().stream().map(this::visit).reduce(pplCommand, (r, e) -> e.attach(r)); - } + @Override + public UnresolvedPlan visitQueryStatement(OpenSearchPPLParser.QueryStatementContext ctx) { + UnresolvedPlan pplCommand = visit(ctx.pplCommands()); + return ctx.commands().stream().map(this::visit).reduce(pplCommand, (r, e) -> e.attach(r)); + } - @Override - public UnresolvedPlan visitSubSearch(OpenSearchPPLParser.SubSearchContext ctx) { - UnresolvedPlan searchCommand = visit(ctx.searchCommand()); - return ctx.commands().stream().map(this::visit).reduce(searchCommand, (r, e) -> e.attach(r)); - } + @Override + public UnresolvedPlan visitSubSearch(OpenSearchPPLParser.SubSearchContext ctx) { + UnresolvedPlan searchCommand = visit(ctx.searchCommand()); + return ctx.commands().stream().map(this::visit).reduce(searchCommand, (r, e) -> e.attach(r)); + } - /** Search command. */ - @Override - public UnresolvedPlan visitSearchFrom(OpenSearchPPLParser.SearchFromContext ctx) { - return visitFromClause(ctx.fromClause()); - } + /** + * Search command. 
+ */ + @Override + public UnresolvedPlan visitSearchFrom(OpenSearchPPLParser.SearchFromContext ctx) { + return visitFromClause(ctx.fromClause()); + } - @Override - public UnresolvedPlan visitSearchFromFilter(OpenSearchPPLParser.SearchFromFilterContext ctx) { - return new Filter(internalVisitExpression(ctx.logicalExpression())) - .attach(visit(ctx.fromClause())); - } + @Override + public UnresolvedPlan visitSearchFromFilter(OpenSearchPPLParser.SearchFromFilterContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())) + .attach(visit(ctx.fromClause())); + } - @Override - public UnresolvedPlan visitSearchFilterFrom(OpenSearchPPLParser.SearchFilterFromContext ctx) { - return new Filter(internalVisitExpression(ctx.logicalExpression())) - .attach(visit(ctx.fromClause())); - } + @Override + public UnresolvedPlan visitSearchFilterFrom(OpenSearchPPLParser.SearchFilterFromContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())) + .attach(visit(ctx.fromClause())); + } - @Override - public UnresolvedPlan visitDescribeCommand(OpenSearchPPLParser.DescribeCommandContext ctx) { - final Relation table = (Relation) visitTableSourceClause(ctx.tableSourceClause()); - QualifiedName tableQualifiedName = table.getTableQualifiedName(); - ArrayList parts = new ArrayList<>(tableQualifiedName.getParts()); - return new DescribeRelation(new QualifiedName(parts)); - } + @Override + public UnresolvedPlan visitDescribeCommand(OpenSearchPPLParser.DescribeCommandContext ctx) { + final Relation table = (Relation) visitTableSourceClause(ctx.tableSourceClause()); + QualifiedName tableQualifiedName = table.getTableQualifiedName(); + ArrayList parts = new ArrayList<>(tableQualifiedName.getParts()); + return new DescribeRelation(new QualifiedName(parts)); + } - /** Where command. 
*/ - @Override - public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext ctx) { - return new Filter(internalVisitExpression(ctx.logicalExpression())); - } + /** + * Where command. + */ + @Override + public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())); + } - @Override - public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { - return new Correlation(ctx.correlationType().getText(), - ctx.fieldList().fieldExpression().stream() - .map(OpenSearchPPLParser.FieldExpressionContext::qualifiedName) - .map(this::internalVisitExpression) - .map(u -> (QualifiedName) u) - .collect(Collectors.toList()), - Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()), - expressionBuilder.visit(ctx.scopeClause().value), - SpanUnit.of(Objects.isNull(ctx.scopeClause().unit) ? "" : ctx.scopeClause().unit.getText())), - Objects.isNull(ctx.mappingList()) ? new FieldsMapping(emptyList()) : new FieldsMapping(ctx.mappingList() - .mappingClause().stream() - .map(this::internalVisitExpression) - .collect(Collectors.toList()))); - } + @Override + public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { + return new Correlation(ctx.correlationType().getText(), + ctx.fieldList().fieldExpression().stream() + .map(OpenSearchPPLParser.FieldExpressionContext::qualifiedName) + .map(this::internalVisitExpression) + .map(u -> (QualifiedName) u) + .collect(Collectors.toList()), + Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()), + expressionBuilder.visit(ctx.scopeClause().value), + SpanUnit.of(Objects.isNull(ctx.scopeClause().unit) ? "" : ctx.scopeClause().unit.getText())), + Objects.isNull(ctx.mappingList()) ? 
new FieldsMapping(emptyList()) : new FieldsMapping(ctx.mappingList() + .mappingClause().stream() + .map(this::internalVisitExpression) + .collect(Collectors.toList()))); + } - @Override - public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ctx) { - Join.JoinType joinType = getJoinType(ctx.joinType()); - if (ctx.joinCriteria() == null) { - joinType = Join.JoinType.CROSS; - } - Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - String leftAlias = ctx.sideAlias().leftAlias.getText(); - String rightAlias = ctx.sideAlias().rightAlias.getText(); - // TODO when sub-search is supported, this part need to change. Now relation is the only supported plan for right side - UnresolvedPlan right = new SubqueryAlias(rightAlias, new Relation(this.internalVisitExpression(ctx.tableSource()), rightAlias)); - Optional joinCondition = - ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); - - return new Join(right, leftAlias, rightAlias, joinType, joinCondition, joinHint); - } + @Override + public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ctx) { + Join.JoinType joinType = getJoinType(ctx.joinType()); + if (ctx.joinCriteria() == null) { + joinType = Join.JoinType.CROSS; + } + Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); + String leftAlias = ctx.sideAlias().leftAlias.getText(); + String rightAlias = ctx.sideAlias().rightAlias.getText(); + // TODO when sub-search is supported, this part need to change. Now relation is the only supported plan for right side + UnresolvedPlan right = new SubqueryAlias(rightAlias, new Relation(this.internalVisitExpression(ctx.tableSource()), rightAlias)); + Optional joinCondition = + ctx.joinCriteria() == null ? 
Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); + + return new Join(right, leftAlias, rightAlias, joinType, joinCondition, joinHint); + } - private Join.JoinHint getJoinHint(OpenSearchPPLParser.JoinHintListContext ctx) { - Join.JoinHint joinHint; - if (ctx == null) { - joinHint = new Join.JoinHint(); - } else { - joinHint = new Join.JoinHint( - ctx.hintPair().stream() - .map(pCtx -> expressionBuilder.visit(pCtx)) - .filter(e -> e instanceof EqualTo) - .map(e -> (EqualTo) e) - .collect(Collectors.toMap( - k -> k.getLeft().toString(), // always literal - v -> v.getRight().toString(), // always literal - (v1, v2) -> v2, - LinkedHashMap::new))); - } - return joinHint; - } + private Join.JoinHint getJoinHint(OpenSearchPPLParser.JoinHintListContext ctx) { + Join.JoinHint joinHint; + if (ctx == null) { + joinHint = new Join.JoinHint(); + } else { + joinHint = new Join.JoinHint( + ctx.hintPair().stream() + .map(pCtx -> expressionBuilder.visit(pCtx)) + .filter(e -> e instanceof EqualTo) + .map(e -> (EqualTo) e) + .collect(Collectors.toMap( + k -> k.getLeft().toString(), // always literal + v -> v.getRight().toString(), // always literal + (v1, v2) -> v2, + LinkedHashMap::new))); + } + return joinHint; + } - private Join.JoinType getJoinType(OpenSearchPPLParser.JoinTypeContext ctx) { - Join.JoinType joinType; - if (ctx == null) { - joinType = Join.JoinType.INNER; - } else if (ctx.INNER() != null) { - joinType = Join.JoinType.INNER; - } else if (ctx.SEMI() != null) { - joinType = Join.JoinType.SEMI; - } else if (ctx.ANTI() != null) { - joinType = Join.JoinType.ANTI; - } else if (ctx.LEFT() != null) { - joinType = Join.JoinType.LEFT; - } else if (ctx.RIGHT() != null) { - joinType = Join.JoinType.RIGHT; - } else if (ctx.CROSS() != null) { - joinType = Join.JoinType.CROSS; - } else if (ctx.FULL() != null) { - joinType = Join.JoinType.FULL; - } else { - joinType = Join.JoinType.INNER; - } - return joinType; - } + private Join.JoinType 
getJoinType(OpenSearchPPLParser.JoinTypeContext ctx) { + Join.JoinType joinType; + if (ctx == null) { + joinType = Join.JoinType.INNER; + } else if (ctx.INNER() != null) { + joinType = Join.JoinType.INNER; + } else if (ctx.SEMI() != null) { + joinType = Join.JoinType.SEMI; + } else if (ctx.ANTI() != null) { + joinType = Join.JoinType.ANTI; + } else if (ctx.LEFT() != null) { + joinType = Join.JoinType.LEFT; + } else if (ctx.RIGHT() != null) { + joinType = Join.JoinType.RIGHT; + } else if (ctx.CROSS() != null) { + joinType = Join.JoinType.CROSS; + } else if (ctx.FULL() != null) { + joinType = Join.JoinType.FULL; + } else { + joinType = Join.JoinType.INNER; + } + return joinType; + } - /** Fields command. */ - @Override - public UnresolvedPlan visitFieldsCommand(OpenSearchPPLParser.FieldsCommandContext ctx) { - return new Project( - ctx.fieldList().fieldExpression().stream() - .map(this::internalVisitExpression) - .collect(Collectors.toList()), - ArgumentFactory.getArgumentList(ctx)); - } + /** + * Fields command. + */ + @Override + public UnresolvedPlan visitFieldsCommand(OpenSearchPPLParser.FieldsCommandContext ctx) { + return new Project( + ctx.fieldList().fieldExpression().stream() + .map(this::internalVisitExpression) + .collect(Collectors.toList()), + ArgumentFactory.getArgumentList(ctx)); + } - /** Rename command. */ - @Override - public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContext ctx) { - return new Rename( - ctx.renameClasue().stream() - .map( - ct -> - new Alias( - ct.renamedField.getText(), - internalVisitExpression(ct.orignalField))) - .collect(Collectors.toList())); - } + /** + * Rename command. + */ + @Override + public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContext ctx) { + return new Rename( + ctx.renameClasue().stream() + .map( + ct -> + new Alias( + ct.renamedField.getText(), + internalVisitExpression(ct.orignalField))) + .collect(Collectors.toList())); + } - /** Stats command. 
*/ - @Override - public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext ctx) { - ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); - for (OpenSearchPPLParser.StatsAggTermContext aggCtx : ctx.statsAggTerm()) { - UnresolvedExpression aggExpression = internalVisitExpression(aggCtx.statsFunction()); - String name = - aggCtx.alias == null - ? getTextInQuery(aggCtx) - : aggCtx.alias.getText(); - Alias alias = new Alias(name, aggExpression); - aggListBuilder.add(alias); - } - - List groupList = - Optional.ofNullable(ctx.statsByClause()) - .map(OpenSearchPPLParser.StatsByClauseContext::fieldList) - .map( - expr -> - expr.fieldExpression().stream() + /** + * Stats command. + */ + @Override + public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext ctx) { + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + for (OpenSearchPPLParser.StatsAggTermContext aggCtx : ctx.statsAggTerm()) { + UnresolvedExpression aggExpression = internalVisitExpression(aggCtx.statsFunction()); + String name = + aggCtx.alias == null + ? 
getTextInQuery(aggCtx) + : aggCtx.alias.getText(); + Alias alias = new Alias(name, aggExpression); + aggListBuilder.add(alias); + } + + List groupList = + Optional.ofNullable(ctx.statsByClause()) + .map(OpenSearchPPLParser.StatsByClauseContext::fieldList) .map( - groupCtx -> - (UnresolvedExpression) - new Alias( - getTextInQuery(groupCtx), - internalVisitExpression(groupCtx))) - .collect(Collectors.toList())) - .orElse(emptyList()); - - UnresolvedExpression span = - Optional.ofNullable(ctx.statsByClause()) - .map(OpenSearchPPLParser.StatsByClauseContext::bySpanClause) - .map(this::internalVisitExpression) - .orElse(null); - - Aggregation aggregation = - new Aggregation( - aggListBuilder.build(), - emptyList(), - groupList, - span, - ArgumentFactory.getArgumentList(ctx)); - return aggregation; - } + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(emptyList()); + + UnresolvedExpression span = + Optional.ofNullable(ctx.statsByClause()) + .map(OpenSearchPPLParser.StatsByClauseContext::bySpanClause) + .map(this::internalVisitExpression) + .orElse(null); + + Aggregation aggregation = + new Aggregation( + aggListBuilder.build(), + emptyList(), + groupList, + span, + ArgumentFactory.getArgumentList(ctx)); + return aggregation; + } - /** Dedup command. */ - @Override - public UnresolvedPlan visitDedupCommand(OpenSearchPPLParser.DedupCommandContext ctx) { - return new Dedupe(ArgumentFactory.getArgumentList(ctx), getFieldList(ctx.fieldList())); - } + /** + * Dedup command. + */ + @Override + public UnresolvedPlan visitDedupCommand(OpenSearchPPLParser.DedupCommandContext ctx) { + return new Dedupe(ArgumentFactory.getArgumentList(ctx), getFieldList(ctx.fieldList())); + } - /** Head command visitor. 
*/ - @Override - public UnresolvedPlan visitHeadCommand(OpenSearchPPLParser.HeadCommandContext ctx) { - Integer size = ctx.number != null ? Integer.parseInt(ctx.number.getText()) : 10; - Integer from = ctx.from != null ? Integer.parseInt(ctx.from.getText()) : 0; - return new Head(size, from); - } + /** + * Head command visitor. + */ + @Override + public UnresolvedPlan visitHeadCommand(OpenSearchPPLParser.HeadCommandContext ctx) { + Integer size = ctx.number != null ? Integer.parseInt(ctx.number.getText()) : 10; + Integer from = ctx.from != null ? Integer.parseInt(ctx.from.getText()) : 0; + return new Head(size, from); + } - /** Sort command. */ - @Override - public UnresolvedPlan visitSortCommand(OpenSearchPPLParser.SortCommandContext ctx) { - return new Sort( - ctx.sortbyClause().sortField().stream() - .map(sort -> (Field) internalVisitExpression(sort)) - .collect(Collectors.toList())); - } + /** + * Sort command. + */ + @Override + public UnresolvedPlan visitSortCommand(OpenSearchPPLParser.SortCommandContext ctx) { + return new Sort( + ctx.sortbyClause().sortField().stream() + .map(sort -> (Field) internalVisitExpression(sort)) + .collect(Collectors.toList())); + } - /** Eval command. */ - @Override - public UnresolvedPlan visitEvalCommand(OpenSearchPPLParser.EvalCommandContext ctx) { - return new Eval( - ctx.evalClause().stream() - .map(ct -> (Let) internalVisitExpression(ct)) - .collect(Collectors.toList())); - } + /** + * Eval command. 
+ */ + @Override + public UnresolvedPlan visitEvalCommand(OpenSearchPPLParser.EvalCommandContext ctx) { + return new Eval( + ctx.evalClause().stream() + .map(ct -> (Let) internalVisitExpression(ct)) + .collect(Collectors.toList())); + } - private List getGroupByList(OpenSearchPPLParser.ByClauseContext ctx) { - return ctx.fieldList().fieldExpression().stream() - .map(this::internalVisitExpression) - .collect(Collectors.toList()); - } + private List getGroupByList(OpenSearchPPLParser.ByClauseContext ctx) { + return ctx.fieldList().fieldExpression().stream() + .map(this::internalVisitExpression) + .collect(Collectors.toList()); + } - private List getFieldList(OpenSearchPPLParser.FieldListContext ctx) { - return ctx.fieldExpression().stream() - .map(field -> (Field) internalVisitExpression(field)) - .collect(Collectors.toList()); - } + private List getFieldList(OpenSearchPPLParser.FieldListContext ctx) { + return ctx.fieldExpression().stream() + .map(field -> (Field) internalVisitExpression(field)) + .collect(Collectors.toList()); + } - @Override - public UnresolvedPlan visitGrokCommand(OpenSearchPPLParser.GrokCommandContext ctx) { - UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); - Literal pattern = (Literal) internalVisitExpression(ctx.pattern); + @Override + public UnresolvedPlan visitGrokCommand(OpenSearchPPLParser.GrokCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + Literal pattern = (Literal) internalVisitExpression(ctx.pattern); - return new Parse(ParseMethod.GROK, sourceField, pattern, ImmutableMap.of()); - } + return new Parse(ParseMethod.GROK, sourceField, pattern, ImmutableMap.of()); + } - @Override - public UnresolvedPlan visitParseCommand(OpenSearchPPLParser.ParseCommandContext ctx) { - UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); - Literal pattern = (Literal) internalVisitExpression(ctx.pattern); + @Override + public UnresolvedPlan 
visitParseCommand(OpenSearchPPLParser.ParseCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + Literal pattern = (Literal) internalVisitExpression(ctx.pattern); - return new Parse(ParseMethod.REGEX, sourceField, pattern, ImmutableMap.of()); - } + return new Parse(ParseMethod.REGEX, sourceField, pattern, ImmutableMap.of()); + } - @Override - public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandContext ctx) { - UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); - ImmutableMap.Builder builder = ImmutableMap.builder(); - ctx.patternsParameter() - .forEach( - x -> { - builder.put( - x.children.get(0).toString(), - (Literal) internalVisitExpression(x.children.get(2))); - }); - java.util.Map arguments = builder.build(); - Literal pattern = arguments.getOrDefault("pattern", new Literal("", DataType.STRING)); - - return new Parse(ParseMethod.PATTERNS, sourceField, pattern, arguments); - } + @Override + public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.patternsParameter() + .forEach( + x -> { + builder.put( + x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + java.util.Map arguments = builder.build(); + Literal pattern = arguments.getOrDefault("pattern", new Literal("", DataType.STRING)); + + return new Parse(ParseMethod.PATTERNS, sourceField, pattern, arguments); + } - /** Lookup command */ - @Override - public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { - Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); - Lookup.OutputStrategy strategy = - ctx.APPEND() != null ? 
Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; - java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); - java.util.Map outputCandidateList = - ctx.APPEND() == null && ctx.REPLACE() == null ? emptyMap() : buildLookupPair(ctx.outputCandidateList().lookupPair()); - return new Lookup(new SubqueryAlias(lookupRelation, "_l"), lookupMappingList, strategy, outputCandidateList); - } + /** + * Lookup command + */ + @Override + public UnresolvedPlan visitLookupCommand(OpenSearchPPLParser.LookupCommandContext ctx) { + Relation lookupRelation = new Relation(this.internalVisitExpression(ctx.tableSource())); + Lookup.OutputStrategy strategy = + ctx.APPEND() != null ? Lookup.OutputStrategy.APPEND : Lookup.OutputStrategy.REPLACE; + java.util.Map lookupMappingList = buildLookupPair(ctx.lookupMappingList().lookupPair()); + java.util.Map outputCandidateList = + ctx.APPEND() == null && ctx.REPLACE() == null ? emptyMap() : buildLookupPair(ctx.outputCandidateList().lookupPair()); + return new Lookup(new SubqueryAlias(lookupRelation, "_l"), lookupMappingList, strategy, outputCandidateList); + } - private java.util.Map buildLookupPair(List ctx) { - return ctx.stream() - .map(of -> expressionBuilder.visitLookupPair(of)) - .map(And.class::cast) - .collect(Collectors.toMap(and -> (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); - } + private java.util.Map buildLookupPair(List ctx) { + return ctx.stream() + .map(of -> expressionBuilder.visitLookupPair(of)) + .map(And.class::cast) + .collect(Collectors.toMap(and -> (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); + } - /** Top command. 
*/ - @Override - public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { - ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); - ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); - ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), - Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); - String name = field.qualifiedName().getText(); - Alias alias = new Alias("count_"+name, aggExpression); - aggListBuilder.add(alias); - // group by the `field-list` as the mandatory groupBy fields - groupListBuilder.add(internalVisitExpression(field)); - }); - - // group by the `by-clause` as the optional groupBy fields - groupListBuilder.addAll( - Optional.ofNullable(ctx.byClause()) - .map(OpenSearchPPLParser.ByClauseContext::fieldList) - .map( - expr -> - expr.fieldExpression().stream() - .map( - groupCtx -> - (UnresolvedExpression) - new Alias( - getTextInQuery(groupCtx), - internalVisitExpression(groupCtx))) - .collect(Collectors.toList())) - .orElse(emptyList()) - ); - UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); - TopAggregation aggregation = - new TopAggregation( - Optional.ofNullable((Literal) unresolvedPlan), - aggListBuilder.build(), - aggListBuilder.build(), - groupListBuilder.build()); - return aggregation; - } + /** + * Top command. 
+ */ + @Override + public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + ctx.fieldList().fieldExpression().forEach(field -> { + UnresolvedExpression aggExpression = new AggregateFunction("count", internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); + String name = field.qualifiedName().getText(); + Alias alias = new Alias("count_" + name, aggExpression); + aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); + }); + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( + Optional.ofNullable(ctx.byClause()) + .map(OpenSearchPPLParser.ByClauseContext::fieldList) + .map( + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(emptyList()) + ); + UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + TopAggregation aggregation = + new TopAggregation( + Optional.ofNullable((Literal) unresolvedPlan), + aggListBuilder.build(), + aggListBuilder.build(), + groupListBuilder.build()); + return aggregation; + } @Override public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryCommandContext ctx) { return new FieldSummary(ctx.fieldsummaryParameter().stream().map(arg -> expressionBuilder.visit(arg)).collect(Collectors.toList())); } - /** Rare command. 
*/ - @Override - public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { - ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); - ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); - ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), - Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); - String name = field.qualifiedName().getText(); - Alias alias = new Alias("count_"+name, aggExpression); - aggListBuilder.add(alias); - // group by the `field-list` as the mandatory groupBy fields - groupListBuilder.add(internalVisitExpression(field)); - }); - - // group by the `by-clause` as the optional groupBy fields - groupListBuilder.addAll( - Optional.ofNullable(ctx.byClause()) - .map(OpenSearchPPLParser.ByClauseContext::fieldList) - .map( - expr -> - expr.fieldExpression().stream() - .map( - groupCtx -> - (UnresolvedExpression) - new Alias( - getTextInQuery(groupCtx), - internalVisitExpression(groupCtx))) - .collect(Collectors.toList())) - .orElse(emptyList()) - ); - RareAggregation aggregation = - new RareAggregation( - aggListBuilder.build(), - aggListBuilder.build(), - groupListBuilder.build()); - return aggregation; - } + /** + * Rare command. 
+ */ + @Override + public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + ctx.fieldList().fieldExpression().forEach(field -> { + UnresolvedExpression aggExpression = new AggregateFunction("count", internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); + String name = field.qualifiedName().getText(); + Alias alias = new Alias("count_" + name, aggExpression); + aggListBuilder.add(alias); + // group by the `field-list` as the mandatory groupBy fields + groupListBuilder.add(internalVisitExpression(field)); + }); + + // group by the `by-clause` as the optional groupBy fields + groupListBuilder.addAll( + Optional.ofNullable(ctx.byClause()) + .map(OpenSearchPPLParser.ByClauseContext::fieldList) + .map( + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(emptyList()) + ); + RareAggregation aggregation = + new RareAggregation( + aggListBuilder.build(), + aggListBuilder.build(), + groupListBuilder.build()); + return aggregation; + } - /** From clause. */ - @Override - public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { - return visitTableSourceClause(ctx.tableSourceClause()); - } + /** + * From clause. 
+ */ + @Override + public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { + return visitTableSourceClause(ctx.tableSourceClause()); + } - @Override - public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return new Relation( - ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); - } + @Override + public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { + return new Relation( + ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + } - @Override - public UnresolvedPlan visitTableFunction(OpenSearchPPLParser.TableFunctionContext ctx) { - ImmutableList.Builder builder = ImmutableList.builder(); - ctx.functionArgs() - .functionArg() - .forEach( - arg -> { - String argName = (arg.ident() != null) ? arg.ident().getText() : null; - builder.add( - new UnresolvedArgument( - argName, this.internalVisitExpression(arg.valueExpression()))); - }); - return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); - } + @Override + public UnresolvedPlan visitTableFunction(OpenSearchPPLParser.TableFunctionContext ctx) { + ImmutableList.Builder builder = ImmutableList.builder(); + ctx.functionArgs() + .functionArg() + .forEach( + arg -> { + String argName = (arg.ident() != null) ? arg.ident().getText() : null; + builder.add( + new UnresolvedArgument( + argName, this.internalVisitExpression(arg.valueExpression()))); + }); + return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); + } - /** Navigate to & build AST expression. */ - private UnresolvedExpression internalVisitExpression(ParseTree tree) { - return expressionBuilder.visit(tree); - } + /** + * Navigate to & build AST expression. 
+ */ + private UnresolvedExpression internalVisitExpression(ParseTree tree) { + return expressionBuilder.visit(tree); + } - /** Simply return non-default value for now. */ - @Override - protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPlan nextResult) { - if (nextResult != defaultResult()) { - return nextResult; + /** + * Simply return non-default value for now. + */ + @Override + protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPlan nextResult) { + if (nextResult != defaultResult()) { + return nextResult; + } + return aggregate; } - return aggregate; - } - /** Kmeans command. */ - @Override - public UnresolvedPlan visitKmeansCommand(OpenSearchPPLParser.KmeansCommandContext ctx) { - ImmutableMap.Builder builder = ImmutableMap.builder(); - ctx.kmeansParameter() - .forEach( - x -> { - builder.put( - x.children.get(0).toString(), - (Literal) internalVisitExpression(x.children.get(2))); - }); - return new Kmeans(builder.build()); - } + /** + * Kmeans command. + */ + @Override + public UnresolvedPlan visitKmeansCommand(OpenSearchPPLParser.KmeansCommandContext ctx) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.kmeansParameter() + .forEach( + x -> { + builder.put( + x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + return new Kmeans(builder.build()); + } - @Override - public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandContext ctx) { - // ctx contain result of parsing fillnull command. 
Lets transform it to UnresolvedPlan which is FillNull - FillNullWithTheSameValueContext sameValueContext = ctx.fillNullWithTheSameValue(); - FillNullWithFieldVariousValuesContext variousValuesContext = ctx.fillNullWithFieldVariousValues(); - if (sameValueContext != null) { - // todo consider using expression instead of Literal - UnresolvedExpression replaceNullWithMe = internalVisitExpression(sameValueContext.nullReplacement().expression()); - List fieldsToReplace = sameValueContext.nullableField() - .stream() - .map(this::internalVisitExpression) - .map(Field.class::cast) - .collect(Collectors.toList()); - return new FillNull(ofSameValue(replaceNullWithMe, fieldsToReplace)); - } else if (variousValuesContext != null) { - List nullableFieldFills = IntStream.range(0, variousValuesContext.nullableField().size()) - .mapToObj(index -> { - variousValuesContext.nullableField(index); - UnresolvedExpression replaceNullWithMe = internalVisitExpression(variousValuesContext.nullReplacement(index).expression()); - Field nullableFieldReference = (Field) internalVisitExpression(variousValuesContext.nullableField(index)); - return new NullableFieldFill(nullableFieldReference, replaceNullWithMe); - }) - .collect(Collectors.toList()); - return new FillNull(ofVariousValue(nullableFieldFills)); - } else { - throw new SyntaxCheckException("Invalid fillnull command"); + @Override + public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandContext ctx) { + // ctx contain result of parsing fillnull command. 
Lets transform it to UnresolvedPlan which is FillNull + FillNullWithTheSameValueContext sameValueContext = ctx.fillNullWithTheSameValue(); + FillNullWithFieldVariousValuesContext variousValuesContext = ctx.fillNullWithFieldVariousValues(); + if (sameValueContext != null) { + // todo consider using expression instead of Literal + UnresolvedExpression replaceNullWithMe = internalVisitExpression(sameValueContext.nullReplacement().expression()); + List fieldsToReplace = sameValueContext.nullableField() + .stream() + .map(this::internalVisitExpression) + .map(Field.class::cast) + .collect(Collectors.toList()); + return new FillNull(ofSameValue(replaceNullWithMe, fieldsToReplace)); + } else if (variousValuesContext != null) { + List nullableFieldFills = IntStream.range(0, variousValuesContext.nullableField().size()) + .mapToObj(index -> { + variousValuesContext.nullableField(index); + UnresolvedExpression replaceNullWithMe = internalVisitExpression(variousValuesContext.nullReplacement(index).expression()); + Field nullableFieldReference = (Field) internalVisitExpression(variousValuesContext.nullableField(index)); + return new NullableFieldFill(nullableFieldReference, replaceNullWithMe); + }) + .collect(Collectors.toList()); + return new FillNull(ofVariousValue(nullableFieldFills)); + } else { + throw new SyntaxCheckException("Invalid fillnull command"); + } } - } - /** AD command. */ - @Override - public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { - throw new RuntimeException("AD Command is not supported "); + /** + * AD command. + */ + @Override + public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { + throw new RuntimeException("AD Command is not supported "); - } + } - /** ml command. */ - @Override - public UnresolvedPlan visitMlCommand(OpenSearchPPLParser.MlCommandContext ctx) { - throw new RuntimeException("ML Command is not supported "); - } + /** + * ml command. 
+ */ + @Override + public UnresolvedPlan visitMlCommand(OpenSearchPPLParser.MlCommandContext ctx) { + throw new RuntimeException("ML Command is not supported "); + } - /** Get original text in query. */ - private String getTextInQuery(ParserRuleContext ctx) { - Token start = ctx.getStart(); - Token stop = ctx.getStop(); - return query.substring(start.getStartIndex(), stop.getStopIndex() + 1); - } + /** + * Get original text in query. + */ + private String getTextInQuery(ParserRuleContext ctx) { + Token start = ctx.getStart(); + Token stop = ctx.getStop(); + return query.substring(start.getStartIndex(), stop.getStopIndex() + 1); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 0040fbc9e..dda0bf5c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -51,7 +51,6 @@ import java.util.stream.IntStream; import java.util.stream.Stream; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.EXCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; @@ -193,14 +192,6 @@ public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.F return new NamedExpression(INCLUDEFIELDS,new FieldList(includeFields)); } - @Override - public UnresolvedExpression visitFieldsummaryExcludeFields(OpenSearchPPLParser.FieldsummaryExcludeFieldsContext ctx) { - List excludeFields = ctx.fieldList().fieldExpression().stream() - .map(this::visitFieldExpression) - .map(p->(Field)p) - .collect(Collectors.toList()); - return new NamedExpression(EXCLUDEFIELDS,new 
FieldList(excludeFields)); - } @Override public UnresolvedExpression visitFieldsummaryTopValues(OpenSearchPPLParser.FieldsummaryTopValuesContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java new file mode 100644 index 000000000..c26f857cc --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.opensearch.sql.ast.tree.FieldSummary; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; +import java.util.Collections; + +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface FieldSummaryTransformer { + + String COUNT = "Count"; + String COUNT_DISTINCT = "CountDistinct"; + String MAX = "Max"; + String MIN = "Min"; + String AVG = "Avg"; + String TYPE = "Type"; + String TOP_VALUES = "TopValues"; + String NULLS = "Nulls"; + + String FIELD = "Field"; + + /** + * translate the field summary into the following query: + * SELECT + * -- For column1 --- + * 'column1' AS Field, + * COUNT(column1) AS Count, + * COUNT(DISTINCT column1) AS Distinct, + * MIN(column1) AS Min, + * MAX(column1) AS Max, + * 
AVG(CAST(column1 AS DOUBLE)) AS Avg, + * typeof(column1) AS Type, + * COLLECT_LIST(STRUCT(column1, COUNT(column1))) AS top_values, + * COUNT(*) - COUNT(column1) AS Nulls, + * + * -- For column2 --- + * 'column2' AS Field, + * COUNT(column2) AS Count, + * COUNT(DISTINCT column2) AS Distinct, + * MIN(column2) AS Min, + * MAX(column2) AS Max, + * AVG(CAST(column2 AS DOUBLE)) AS Avg, + * typeof(column2) AS Type, + * COLLECT_LIST(STRUCT(column2, COUNT(column2))) AS top_values, + * COUNT(*) - COUNT(column2) AS Nulls + * FROM ... + */ + static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { + fieldSummary.getIncludeFields().forEach(field -> { + Literal fieldLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + context.withProjectedFields(Collections.singletonList(field)); + //Alias for the field name as Field + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(fieldLiteral, + FIELD, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for the count(field) as Count + UnresolvedFunction count = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), false, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(count, + COUNT, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for the count(DISTINCT field) as CountDistinct + UnresolvedFunction countDistinct = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), true, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(countDistinct, + COUNT_DISTINCT, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for 
the MAX(field) as MAX + UnresolvedFunction max = new UnresolvedFunction(seq("MAX"), seq(fieldLiteral), false, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(max, + MAX, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for the MAX(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq("MIN"), seq(fieldLiteral), false, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(min, + MIN, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for the AVG(field) as Avg + UnresolvedFunction avg = new UnresolvedFunction(seq("AVG"), seq(fieldLiteral), false, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(avg, + AVG, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + //Alias for the typeOf(field) as Type + UnresolvedFunction type = new UnresolvedFunction(seq("TYPEOF"), seq(fieldLiteral), false, empty(), false); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(type, + TYPE, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + false + ); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(new 
java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()) + )); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction( + seq("COUNT"), + seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), + false, + empty(), + false + ); + + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()) + )); + }); + + return context.getPlan(); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala similarity index 97% rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index 620fdcaab..49f628a4b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -15,7 +15,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers -class PPLLogicalPlanFieldsummaryCommandTranslatorTestSuite +class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite with PlanTest with LogicalPlanTestUtils From 4f303062228ed30816c4ff8d42ab6da939ef34aa Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 10 Oct 2024 
17:31:37 -0700 Subject: [PATCH 03/12] update sample query fix scala style format Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 99 +++++++++++++++++ .../opensearch/sql/ast/tree/FieldSummary.java | 4 +- .../ppl/utils/FieldSummaryTransformer.java | 105 +++++++++++------- ...eldSummaryCommandTranslatorTestSuite.scala | 13 ++- 4 files changed, 172 insertions(+), 49 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala new file mode 100644 index 000000000..fc2e8caac --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFieldSummaryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNullableTableHttpLog(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + 
job.stop() + job.awaitTermination() + } + } + + test("test fillnull with one null replacement value and one column") { +// val frame = sql(s""" +// | source = $testTable | fieldsummary includefields= status_code, id, response_time topvalues=5 nulls=true +// | """.stripMargin) + + val frame = sql(s""" + | SELECT + | 'status_code' AS Field, + | COUNT(status_code) AS Count, + | COUNT(DISTINCT status_code) AS Distinct, + | MIN(status_code) AS Min, + | MAX(status_code) AS Max, + | AVG(CAST(status_code AS DOUBLE)) AS Avg, + | typeof(status_code) AS Type, + | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) + | FROM ( + | SELECT status_code, COUNT(*) AS count_status + | FROM $testTable + | GROUP BY status_code + | ORDER BY count_status DESC + | LIMIT 5 + | )) AS top_values, + | COUNT(*) - COUNT(status_code) AS Nulls + | FROM $testTable + | GROUP BY typeof(status_code) + | + | UNION ALL + | + | SELECT + | 'id' AS Field, + | COUNT(id) AS Count, + | COUNT(DISTINCT id) AS Distinct, + | MIN(id) AS Min, + | MAX(id) AS Max, + | AVG(CAST(id AS DOUBLE)) AS Avg, + | typeof(id) AS Type, + | (SELECT COLLECT_LIST(STRUCT(id, count_id)) + | FROM ( + | SELECT id, COUNT(*) AS count_id + | FROM $testTable + | GROUP BY id + | ORDER BY count_id DESC + | LIMIT 5 + | )) AS top_values, + | COUNT(*) - COUNT(id) AS Nulls + | FROM $testTable + | GROUP BY typeof(id) + |""".stripMargin) + + val results: Array[Row] = frame.collect() + // Print each row in a readable format + // scalastyle:off println + results.foreach(row => println(row.mkString(", "))) + // scalastyle:on println + +// val logicalPlan: LogicalPlan = frame.queryExecution.logical +// val expectedPlan = ? 
+// comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index 1b6029803..a442d77fa 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -5,20 +5,20 @@ package org.opensearch.sql.ast.tree; -import com.google.common.collect.ImmutableList; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.NamedExpression; -import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; +import java.util.stream.Collectors; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index c26f857cc..225fc7a83 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -36,29 +36,50 @@ public interface FieldSummaryTransformer { /** * translate the field summary into the following query: - * SELECT - * -- For column1 --- - * 'column1' AS Field, - * COUNT(column1) AS Count, - * 
COUNT(DISTINCT column1) AS Distinct, - * MIN(column1) AS Min, - * MAX(column1) AS Max, - * AVG(CAST(column1 AS DOUBLE)) AS Avg, - * typeof(column1) AS Type, - * COLLECT_LIST(STRUCT(column1, COUNT(column1))) AS top_values, - * COUNT(*) - COUNT(column1) AS Nulls, + * ----------------------------------------------------- + * // for each column create statement: + * SELECT + * 'column-1' AS Field, + * COUNT(column-1) AS Count, + * COUNT(DISTINCT column-1) AS Distinct, + * MIN(column-1) AS Min, + * MAX(column-1) AS Max, + * AVG(CAST(column-1 AS DOUBLE)) AS Avg, + * typeof(column-1) AS Type, + * (SELECT COLLECT_LIST(STRUCT(column-1, count_status)) + * FROM ( + * SELECT column-1, COUNT(*) AS count_status + * FROM $testTable + * GROUP BY column-1 + * ORDER BY count_status DESC + * LIMIT 5 + * )) AS top_values, + * COUNT(*) - COUNT(column-1) AS Nulls + * FROM $testTable + * GROUP BY typeof(column-1) + * + * // union all queries + * UNION ALL * - * -- For column2 --- - * 'column2' AS Field, - * COUNT(column2) AS Count, - * COUNT(DISTINCT column2) AS Distinct, - * MIN(column2) AS Min, - * MAX(column2) AS Max, - * AVG(CAST(column2 AS DOUBLE)) AS Avg, - * typeof(column2) AS Type, - * COLLECT_LIST(STRUCT(column2, COUNT(column2))) AS top_values, - * COUNT(*) - COUNT(column2) AS Nulls - * FROM ... 
+ * SELECT + * 'column-2' AS Field, + * COUNT(column-2) AS Count, + * COUNT(DISTINCT column-2) AS Distinct, + * MIN(column-2) AS Min, + * MAX(column-2) AS Max, + * AVG(CAST(column-2 AS DOUBLE)) AS Avg, + * typeof(column-2) AS Type, + * (SELECT COLLECT_LIST(STRUCT(column-2, count_column-2)) + * FROM ( + * SELECT column-, COUNT(*) AS count_column- + * FROM $testTable + * GROUP BY column-2 + * ORDER BY count_column- DESC + * LIMIT 5 + * )) AS top_values, + * COUNT(*) - COUNT(column-2) AS Nulls + * FROM $testTable + * GROUP BY typeof(column-2) */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getIncludeFields().forEach(field -> { @@ -154,25 +175,27 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont Option.empty(), seq(new java.util.ArrayList()) )); - - // Alias COUNT(*) - COUNT(column2) AS Nulls - UnresolvedFunction countStar = new UnresolvedFunction( - seq("COUNT"), - seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), - false, - empty(), - false - ); - - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( - new Subtract(countStar, count), - NULLS, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()) - )); + + if (fieldSummary.isNulls()) { + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction( + seq("COUNT"), + seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), + false, + empty(), + false + ); + + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()) + )); + } }); return context.getPlan(); diff --git 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index 49f628a4b..e90d587ae 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -5,15 +5,16 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} -import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} -import org.scalatest.matchers.should.Matchers class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite @@ -24,7 +25,7 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - test("test fieldsummary with `includefields=status_code,user_id,response_time`") { + ignore("test fieldsummary with `includefields=status_code,user_id,response_time`") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ 
-52,4 +53,4 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } -} \ No newline at end of file +} From 7bcce2fe9818e182b8ccb13c6082b6bd2ea184cb Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 15 Oct 2024 16:53:10 -0700 Subject: [PATCH 04/12] support spark prior to 3.5 with its extended table identifier (existing table identifier only has 2 parts) Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 113 ++++---- .../opensearch/sql/ast/tree/FieldSummary.java | 6 +- .../function/BuiltinFunctionName.java | 1 + .../sql/ppl/CatalystQueryPlanVisitor.java | 21 ++ .../sql/ppl/utils/AggregatorTranslator.java | 28 +- .../ppl/utils/FieldSummaryTransformer.java | 242 +++++++++--------- ...eldSummaryCommandTranslatorTestSuite.scala | 44 ++-- 7 files changed, 255 insertions(+), 200 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index fc2e8caac..db00adae5 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -36,62 +36,79 @@ class FlintSparkPPLFieldSummaryITSuite } } - test("test fillnull with one null replacement value and one column") { -// val frame = sql(s""" -// | source = $testTable | fieldsummary includefields= status_code, id, response_time topvalues=5 nulls=true -// | """.stripMargin) - + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { val frame = sql(s""" - | SELECT - | 'status_code' AS Field, - | COUNT(status_code) AS Count, - | COUNT(DISTINCT status_code) AS Distinct, - | MIN(status_code) AS Min, - | MAX(status_code) 
AS Max, - | AVG(CAST(status_code AS DOUBLE)) AS Avg, - | typeof(status_code) AS Type, - | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) - | FROM ( - | SELECT status_code, COUNT(*) AS count_status - | FROM $testTable - | GROUP BY status_code - | ORDER BY count_status DESC - | LIMIT 5 - | )) AS top_values, - | COUNT(*) - COUNT(status_code) AS Nulls - | FROM $testTable - | GROUP BY typeof(status_code) - | - | UNION ALL - | - | SELECT - | 'id' AS Field, - | COUNT(id) AS Count, - | COUNT(DISTINCT id) AS Distinct, - | MIN(id) AS Min, - | MAX(id) AS Max, - | AVG(CAST(id AS DOUBLE)) AS Avg, - | typeof(id) AS Type, - | (SELECT COLLECT_LIST(STRUCT(id, count_id)) - | FROM ( - | SELECT id, COUNT(*) AS count_id - | FROM $testTable - | GROUP BY id - | ORDER BY count_id DESC - | LIMIT 5 - | )) AS top_values, - | COUNT(*) - COUNT(id) AS Nulls - | FROM $testTable - | GROUP BY typeof(id) - |""".stripMargin) + | source = $testTable | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + +/* + val frame = sql(s""" + | SELECT + | 'status_code' AS Field, + | COUNT(status_code) AS Count, + | COUNT(DISTINCT status_code) AS Distinct, + | MIN(status_code) AS Min, + | MAX(status_code) AS Max, + | AVG(CAST(status_code AS DOUBLE)) AS Avg, + | typeof(status_code) AS Type, + | COUNT(*) - COUNT(status_code) AS Nulls + | FROM $testTable + | GROUP BY typeof(status_code) + | """.stripMargin) +*/ + +// val frame = sql(s""" +// | SELECT +// | 'status_code' AS Field, +// | COUNT(status_code) AS Count, +// | COUNT(DISTINCT status_code) AS Distinct, +// | MIN(status_code) AS Min, +// | MAX(status_code) AS Max, +// | AVG(CAST(status_code AS DOUBLE)) AS Avg, +// | typeof(status_code) AS Type, +// | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) +// | FROM ( +// | SELECT status_code, COUNT(*) AS count_status +// | FROM $testTable +// | GROUP BY status_code +// | ORDER BY count_status DESC +// | LIMIT 5 +// | )) AS top_values, +// | COUNT(*) - COUNT(status_code) AS 
Nulls +// | FROM $testTable +// | GROUP BY typeof(status_code) +// | +// | UNION ALL +// | +// | SELECT +// | 'id' AS Field, +// | COUNT(id) AS Count, +// | COUNT(DISTINCT id) AS Distinct, +// | MIN(id) AS Min, +// | MAX(id) AS Max, +// | AVG(CAST(id AS DOUBLE)) AS Avg, +// | typeof(id) AS Type, +// | (SELECT COLLECT_LIST(STRUCT(id, count_id)) +// | FROM ( +// | SELECT id, COUNT(*) AS count_id +// | FROM $testTable +// | GROUP BY id +// | ORDER BY count_id DESC +// | LIMIT 5 +// | )) AS top_values, +// | COUNT(*) - COUNT(id) AS Nulls +// | FROM $testTable +// | GROUP BY typeof(id) +// |""".stripMargin) val results: Array[Row] = frame.collect() // Print each row in a readable format + val logicalPlan: LogicalPlan = frame.queryExecution.logical // scalastyle:off println results.foreach(row => println(row.mkString(", "))) + println(logicalPlan) // scalastyle:on println -// val logicalPlan: LogicalPlan = frame.queryExecution.logical // val expectedPlan = ? // comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index a442d77fa..1d3b9ffed 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -8,7 +8,6 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Field; @@ -18,7 +17,6 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; -import java.util.stream.Collectors; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static 
org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; @@ -30,7 +28,7 @@ public class FieldSummary extends UnresolvedPlan { private List includeFields; private int topValues; - private boolean nulls; + private boolean ignoreNull; private List collect; private UnresolvedPlan child; @@ -40,7 +38,7 @@ public FieldSummary(List collect) { .forEach(exp -> { switch (((NamedExpression) exp).getExpressionId()) { case NULLS: - this.nulls = (boolean) ((Literal) exp.getChild().get(0)).getValue(); + this.ignoreNull = (boolean) ((Literal) exp.getChild().get(0)).getValue(); break; case TOPVALUES: this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 6b549663a..1f58f92d1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -160,6 +160,7 @@ public enum BuiltinFunctionName { AVG(FunctionName.of("avg")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), + COUNT_DISTINCT(FunctionName.of("count_distinct")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), // sample variance diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 76a7a0c79..8482f4be2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -383,6 +383,27 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { } @Override + /** + * 'Union false, false + * :- 'Aggregate ['typeof('status_code)], 
[status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30] + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, 'status_code, count_status, 'count_status)), None)] + * : : +- 'SubqueryAlias __auto_generated_subquery_name + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41] + * : +- 'Project [unresolvedalias('COLLECT_LIST(struct(id, 'id, count_id, 'count_id)), None)] + * : +- 'SubqueryAlias __auto_generated_subquery_name + * : +- 'GlobalLimit 5 + * : +- 'LocalLimit 5 + * : +- 'Sort ['count_id DESC NULLS LAST], true + * : +- 'Aggregate ['id], ['id, 'COUNT(1) AS count_id#38] + * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getChild().get(0).accept(this, context); return FieldSummaryTransformer.translate(fieldSummary, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index 3c367a948..93e2121d3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -12,10 +12,12 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; +import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -27,30 +29,36 @@ */ public interface AggregatorTranslator { + static String aggregationAlias(BuiltinFunctionName functionName, QualifiedName name) { + return functionName.name()+"("+name.toString()+")"; + } + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + boolean distinct = aggregateFunction.getDistinct(); // Additional aggregation function operators will be added here - switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get(); + switch (functionName) { case MAX: - return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: - return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return 
new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); case AVG: - return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: - return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: - return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); case STDDEV_POP: - return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: - return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false); case PERCENTILE: - return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: - return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 225fc7a83..9cc0582f3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -5,33 +5,38 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.Alias$; import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ppl.CatalystPlanContext; import scala.Option; + import java.util.Collections; import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT_DISTINCT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; +import static org.opensearch.sql.ppl.utils.AggregatorTranslator.aggregationAlias; import static 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; public interface FieldSummaryTransformer { - String COUNT = "Count"; - String COUNT_DISTINCT = "CountDistinct"; - String MAX = "Max"; - String MIN = "Min"; - String AVG = "Avg"; - String TYPE = "Type"; String TOP_VALUES = "TopValues"; String NULLS = "Nulls"; - String FIELD = "Field"; /** @@ -39,147 +44,135 @@ public interface FieldSummaryTransformer { * ----------------------------------------------------- * // for each column create statement: * SELECT - * 'column-1' AS Field, - * COUNT(column-1) AS Count, - * COUNT(DISTINCT column-1) AS Distinct, - * MIN(column-1) AS Min, - * MAX(column-1) AS Max, - * AVG(CAST(column-1 AS DOUBLE)) AS Avg, - * typeof(column-1) AS Type, - * (SELECT COLLECT_LIST(STRUCT(column-1, count_status)) + * 'columnA' AS Field, + * COUNT(columnA) AS Count, + * COUNT(DISTINCT columnA) AS Distinct, + * MIN(columnA) AS Min, + * MAX(columnA) AS Max, + * AVG(CAST(columnA AS DOUBLE)) AS Avg, + * typeof(columnA) AS Type, + * (SELECT COLLECT_LIST(STRUCT(columnA, count_status)) * FROM ( - * SELECT column-1, COUNT(*) AS count_status + * SELECT columnA, COUNT(*) AS count_status * FROM $testTable - * GROUP BY column-1 + * GROUP BY columnA * ORDER BY count_status DESC * LIMIT 5 * )) AS top_values, - * COUNT(*) - COUNT(column-1) AS Nulls + * COUNT(*) - COUNT(columnA) AS Nulls * FROM $testTable - * GROUP BY typeof(column-1) + * GROUP BY typeof(columnA) * * // union all queries * UNION ALL * * SELECT - * 'column-2' AS Field, - * COUNT(column-2) AS Count, - * COUNT(DISTINCT column-2) AS Distinct, - * MIN(column-2) AS Min, - * MAX(column-2) AS Max, - * AVG(CAST(column-2 AS DOUBLE)) AS Avg, - * typeof(column-2) AS Type, - * (SELECT COLLECT_LIST(STRUCT(column-2, count_column-2)) + * 'columnB' AS Field, + * COUNT(columnB) AS Count, + * COUNT(DISTINCT columnB) AS Distinct, + * MIN(columnB) AS Min, + * MAX(columnB) AS Max, + * AVG(CAST(columnB AS DOUBLE)) AS Avg, + * 
typeof(columnB) AS Type, + * (SELECT COLLECT_LIST(STRUCT(columnB, count_columnB)) * FROM ( * SELECT column-, COUNT(*) AS count_column- * FROM $testTable - * GROUP BY column-2 + * GROUP BY columnB * ORDER BY count_column- DESC * LIMIT 5 * )) AS top_values, - * COUNT(*) - COUNT(column-2) AS Nulls + * COUNT(*) - COUNT(columnB) AS Nulls * FROM $testTable - * GROUP BY typeof(column-2) + * GROUP BY typeof(columnB) */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getIncludeFields().forEach(field -> { - Literal fieldLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + Literal fieldNameLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(field.getField().getParts())); context.withProjectedFields(Collections.singletonList(field)); - //Alias for the field name as Field - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(fieldLiteral, - FIELD, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); //Alias for the count(field) as Count - UnresolvedFunction count = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(count, - COUNT, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + 
aggregationAlias(COUNT,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); //Alias for the count(DISTINCT field) as CountDistinct - UnresolvedFunction countDistinct = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), true, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(countDistinct, - COUNT_DISTINCT, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + aggregationAlias(COUNT_DISTINCT,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + //Alias for the MAX(field) as MAX - UnresolvedFunction max = new UnresolvedFunction(seq("MAX"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(max, - MAX, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - - //Alias for the MAX(field) as Min - UnresolvedFunction min = new UnresolvedFunction(seq("MIN"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(min, - MIN, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + aggregationAlias(MAX,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); - //Alias for the AVG(field) as Avg - UnresolvedFunction avg = new UnresolvedFunction(seq("AVG"), seq(fieldLiteral), false, empty(), false); - 
context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(avg, - AVG, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + aggregationAlias(MIN,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); - //Alias for the typeOf(field) as Type - UnresolvedFunction type = new UnresolvedFunction(seq("TYPEOF"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(type, - TYPE, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - - // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values - CreateNamedStruct structExpr = new CreateNamedStruct(seq( - fieldLiteral, - count - )); - UnresolvedFunction collectList = new UnresolvedFunction( - seq("COLLECT_LIST"), - seq(structExpr), - false, + //Alias for the AVG(field) as Avg + UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + aggregationAlias(AVG,field.getField()), + NamedExpression.newExprId(), + seq(), empty(), - false - ); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( - collectList, - TOP_VALUES, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()) - )); - - if (fieldSummary.isNulls()) { + seq()); + + if (fieldSummary.getTopValues()>0) { + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + 
UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIgnoreNull() + ); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + Option.empty(), + seq() + )); + } + + if (!fieldSummary.isIgnoreNull()) { // Alias COUNT(*) - COUNT(column2) AS Nulls UnresolvedFunction countStar = new UnresolvedFunction( - seq("COUNT"), + seq(COUNT.name()), seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), false, empty(), @@ -191,11 +184,24 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont new Subtract(countStar, count), NULLS, NamedExpression.newExprId(), - seq(new java.util.ArrayList()), + seq(), Option.empty(), - seq(new java.util.ArrayList()) + seq() )); } + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + aggregationAlias(TYPEOF,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + context.apply(p-> new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, typeOfAlias), p)); + }); return context.getPlan(); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index e90d587ae..f41f9808c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -9,12 +9,11 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project} class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite @@ -25,32 +24,37 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - ignore("test fieldsummary with `includefields=status_code,user_id,response_time`") { + test("test fieldsummary with single field includefields(status_code) & nulls=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan( pplParser, - "source = t | fieldsummary includefields= status_code, user_id, response_time topvalues=5 nulls=true"), + "source = t | fieldsummary includefields= status_code nulls=true"), context) + // Define the table val table = UnresolvedRelation(Seq("t")) - val renameProjectList: Seq[NamedExpression] = - Seq( - UnresolvedStar(None), - Alias( - UnresolvedFunction( - "coalesce", - Seq(UnresolvedAttribute("column_name"), Literal("null replacement value")), - isDistinct = false), - "column_name")()) - val renameProject = Project(renameProjectList, table) - - val dropSourceColumn = - 
DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) - - val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) - comparePlans(expectedPlan, logPlan, checkAnalysis = false) + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "COUNT(status_code)")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), "COUNT_DISTINCT(status_code)")(), + Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN(status_code)")(), + Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MAX(status_code)")(), + Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG(status_code)")(), + Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")() + ) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")()), + aggregateExpressions, + table + ) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logPlan)) } } From c170208ab2ea7a47a1791efbc4d45258824822a1 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 15 Oct 2024 17:50:42 -0700 Subject: [PATCH 05/12] update union queries based summary Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 2 +- .../sql/ppl/CatalystPlanContext.java | 9 ++- .../ppl/utils/FieldSummaryTransformer.java | 41 ++++++----- ...eldSummaryCommandTranslatorTestSuite.scala | 71 
++++++++++++++++++- 4 files changed, 100 insertions(+), 23 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index db00adae5..519de2d54 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -36,7 +36,7 @@ class FlintSparkPPLFieldSummaryITSuite } } - test("test fieldsummary with single field includefields(status_code) & nulls=true ") { + ignore("test fieldsummary with single field includefields(status_code) & nulls=true ") { val frame = sql(s""" | source = $testTable | fieldsummary includefields= status_code nulls=true | """.stripMargin) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 46a016d1a..61762f616 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -154,7 +154,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + + public LogicalPlan applyBranches(List> plans) { + plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); + planBranches.remove(0); + return getPlan(); + } + /** * append plan with evolving plans branches * @@ -281,4 +287,5 @@ public static Optional findRelation(LogicalPlan plan) { // Return null if no UnresolvedRelation is found return Optional.empty(); } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 9cc0582f3..fdaa04934 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -20,6 +20,9 @@ import scala.Option; import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.StringType; @@ -87,8 +90,8 @@ public interface FieldSummaryTransformer { * GROUP BY typeof(columnB) */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { - fieldSummary.getIncludeFields().forEach(field -> { - Literal fieldNameLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + List> aggBranches = fieldSummary.getIncludeFields().stream().map(field -> { + Literal fieldNameLiteral = Literal.create(field.getField().toString(), StringType); UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(field.getField().getParts())); context.withProjectedFields(Collections.singletonList(field)); @@ -103,7 +106,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the count(field) as Count UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); Alias countAlias = Alias$.MODULE$.apply(count, - aggregationAlias(COUNT,field.getField()), + aggregationAlias(COUNT, field.getField()), NamedExpression.newExprId(), seq(), empty(), @@ -112,7 +115,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the count(DISTINCT field) as CountDistinct UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, 
empty(), false); Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, - aggregationAlias(COUNT_DISTINCT,field.getField()), + aggregationAlias(COUNT_DISTINCT, field.getField()), NamedExpression.newExprId(), seq(), empty(), @@ -121,7 +124,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the MAX(field) as MAX UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); Alias maxAlias = Alias$.MODULE$.apply(max, - aggregationAlias(MAX,field.getField()), + aggregationAlias(MAX, field.getField()), NamedExpression.newExprId(), seq(), empty(), @@ -130,7 +133,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the MIN(field) as Min UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); Alias minAlias = Alias$.MODULE$.apply(min, - aggregationAlias(MIN,field.getField()), + aggregationAlias(MIN, field.getField()), NamedExpression.newExprId(), seq(), empty(), @@ -139,13 +142,13 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the AVG(field) as Avg UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); Alias avgAlias = Alias$.MODULE$.apply(avg, - aggregationAlias(AVG,field.getField()), + aggregationAlias(AVG, field.getField()), NamedExpression.newExprId(), seq(), empty(), seq()); - if (fieldSummary.getTopValues()>0) { + if (fieldSummary.getTopValues() > 0) { // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values CreateNamedStruct structExpr = new CreateNamedStruct(seq( fieldLiteral, @@ -159,12 +162,12 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont !fieldSummary.isIgnoreNull() ); context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + Alias$.MODULE$.apply( collectList, TOP_VALUES, 
NamedExpression.newExprId(), seq(), - Option.empty(), + empty(), seq() )); } @@ -173,19 +176,19 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont // Alias COUNT(*) - COUNT(column2) AS Nulls UnresolvedFunction countStar = new UnresolvedFunction( seq(COUNT.name()), - seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), + seq(Literal.create(1, IntegerType)), false, empty(), false ); context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + Alias$.MODULE$.apply( new Subtract(countStar, count), NULLS, NamedExpression.newExprId(), seq(), - Option.empty(), + empty(), seq() )); } @@ -193,17 +196,17 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the typeOf(field) as Type UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, - aggregationAlias(TYPEOF,field.getField()), + aggregationAlias(TYPEOF, field.getField()), NamedExpression.newExprId(), seq(), empty(), seq()); - - //Aggregation - context.apply(p-> new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, typeOfAlias), p)); - }); + //Aggregation + return (Function) p -> new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, typeOfAlias), p); + }).collect(Collectors.toList()); - return context.getPlan(); + LogicalPlan plan = context.applyBranches(aggBranches); + return plan; } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index f41f9808c..0a3800793 100644 --- 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project, Union} class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite @@ -55,6 +55,73 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite ) val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) + } + + test("test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), "TYPEOF(id)")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), "COUNT(id)")(), + Alias(UnresolvedFunction("COUNT", 
Seq(UnresolvedAttribute("id")), isDistinct = true), "COUNT_DISTINCT(id)")(), + Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN(id)")(), + Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), "MAX(id)")(), + Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), "AVG(id)")(), + Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), "TYPEOF(id)")() + ), + table + ) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "COUNT(status_code)")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), "COUNT_DISTINCT(status_code)")(), + Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN(status_code)")(), + Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MAX(status_code)")(), + Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG(status_code)")(), + Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")() + ), + table + ) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "TYPEOF(request_path)")()), + 
Seq( + Alias(Literal("request_path"), "Field")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "COUNT(request_path)")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), "COUNT_DISTINCT(request_path)")(), + Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN(request_path)")(), + Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MAX(request_path)")(), + Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "AVG(request_path)")(), + Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "TYPEOF(request_path)")() + ), + table + ) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) } } From 63c611811a26e40d9520682ac4ecbb463b2b0153 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 15 Oct 2024 20:13:43 -0700 Subject: [PATCH 06/12] update scala fmt style Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 10 +- ...eldSummaryCommandTranslatorTestSuite.scala | 158 +++++++++++++----- 2 files changed, 118 insertions(+), 50 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index 519de2d54..fb1988733 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -41,7 +41,7 @@ class FlintSparkPPLFieldSummaryITSuite | source = $testTable | fieldsummary 
includefields= status_code nulls=true | """.stripMargin) -/* + /* val frame = sql(s""" | SELECT | 'status_code' AS Field, @@ -53,9 +53,9 @@ class FlintSparkPPLFieldSummaryITSuite | typeof(status_code) AS Type, | COUNT(*) - COUNT(status_code) AS Nulls | FROM $testTable - | GROUP BY typeof(status_code) + | GROUP BY typeof(status_code) | """.stripMargin) -*/ + */ // val frame = sql(s""" // | SELECT @@ -76,8 +76,8 @@ class FlintSparkPPLFieldSummaryITSuite // | )) AS top_values, // | COUNT(*) - COUNT(status_code) AS Nulls // | FROM $testTable -// | GROUP BY typeof(status_code) -// | +// | GROUP BY typeof(status_code) +// | // | UNION ALL // | // | SELECT diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index 0a3800793..fa39a4fd4 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -9,6 +9,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} @@ -28,9 +29,7 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan( - pplParser, - "source = t | fieldsummary includefields= status_code nulls=true"), + plan(pplParser, "source = 
t | fieldsummary includefields= status_code nulls=true"), context) // Define the table @@ -39,26 +38,39 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite // Aggregate with functions applied to status_code val aggregateExpressions: Seq[NamedExpression] = Seq( Alias(Literal("status_code"), "Field")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "COUNT(status_code)")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), "COUNT_DISTINCT(status_code)")(), - Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN(status_code)")(), - Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MAX(status_code)")(), - Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG(status_code)")(), - Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")() - ) + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT(status_code)")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT(status_code)")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN(status_code)")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX(status_code)")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG(status_code)")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF(status_code)")()) // Define the aggregate plan with alias for TYPEOF in the aggregation val aggregatePlan = Aggregate( - groupingExpressions = Seq(Alias(UnresolvedFunction("TYPEOF", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")()), + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF(status_code)")()), aggregateExpressions, - table - ) + table) val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) // Compare the two plans comparePlans(expectedPlan, logPlan, false) } - test("test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -73,51 +85,107 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite // Aggregate with functions applied to status_code // Define the aggregate plan with alias for TYPEOF in the aggregation val aggregateIdPlan = Aggregate( - Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), "TYPEOF(id)")()), + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF(id)")()), Seq( Alias(Literal("id"), "Field")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), "COUNT(id)")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), "COUNT_DISTINCT(id)")(), - Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN(id)")(), - Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), "MAX(id)")(), - Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), "AVG(id)")(), - Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), "TYPEOF(id)")() - ), - table - ) + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT(id)")(), + Alias( + 
UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "COUNT_DISTINCT(id)")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN(id)")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX(id)")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG(id)")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF(id)")()), + table) val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) - + // Aggregate with functions applied to status_code // Define the aggregate plan with alias for TYPEOF in the aggregation val aggregateStatusCodePlan = Aggregate( - Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")()), + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF(status_code)")()), Seq( Alias(Literal("status_code"), "Field")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "COUNT(status_code)")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), "COUNT_DISTINCT(status_code)")(), - Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN(status_code)")(), - Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MAX(status_code)")(), - Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG(status_code)")(), - Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")() - ), - table - ) + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT(status_code)")(), + Alias( + 
UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT(status_code)")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN(status_code)")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX(status_code)")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG(status_code)")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF(status_code)")()), + table) val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) // Define the aggregate plan with alias for TYPEOF in the aggregation val aggregatePlan = Aggregate( - Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "TYPEOF(request_path)")()), + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF(request_path)")()), Seq( Alias(Literal("request_path"), "Field")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "COUNT(request_path)")(), - Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), "COUNT_DISTINCT(request_path)")(), - Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN(request_path)")(), - Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MAX(request_path)")(), - Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "AVG(request_path)")(), - Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "TYPEOF(request_path)")() - ), - table - ) + Alias( + UnresolvedFunction( + "COUNT", + 
Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT(request_path)")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "COUNT_DISTINCT(request_path)")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN(request_path)")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX(request_path)")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG(request_path)")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF(request_path)")()), + table) val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) From 049be03c5585320c96e4d3fd25433daa15b64d31 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 15 Oct 2024 20:59:19 -0700 Subject: [PATCH 07/12] update scala fmt style Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 256 +++++++++++++----- .../sql/ppl/utils/AggregatorTranslator.java | 4 - .../ppl/utils/FieldSummaryTransformer.java | 13 +- ...eldSummaryCommandTranslatorTestSuite.scala | 56 ++-- 4 files changed, 222 insertions(+), 107 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index fb1988733..1a6a1006d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -8,7 +8,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, 
Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -36,81 +36,201 @@ class FlintSparkPPLFieldSummaryITSuite } } - ignore("test fieldsummary with single field includefields(status_code) & nulls=true ") { + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { val frame = sql(s""" | source = $testTable | fieldsummary includefields= status_code nulls=true | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 276.0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + 
"TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } - /* - val frame = sql(s""" - | SELECT - | 'status_code' AS Field, - | COUNT(status_code) AS Count, - | COUNT(DISTINCT status_code) AS Distinct, - | MIN(status_code) AS Min, - | MAX(status_code) AS Max, - | AVG(CAST(status_code AS DOUBLE)) AS Avg, - | typeof(status_code) AS Type, - | COUNT(*) - COUNT(status_code) AS Nulls - | FROM $testTable - | GROUP BY typeof(status_code) - | """.stripMargin) - */ - -// val frame = sql(s""" -// | SELECT -// | 'status_code' AS Field, -// | COUNT(status_code) AS Count, -// | COUNT(DISTINCT status_code) AS Distinct, -// | MIN(status_code) AS Min, -// | MAX(status_code) AS Max, -// | AVG(CAST(status_code AS DOUBLE)) AS Avg, -// | typeof(status_code) AS Type, -// | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) -// | FROM ( -// | SELECT status_code, COUNT(*) AS count_status -// | FROM $testTable -// | GROUP BY status_code -// | ORDER BY count_status DESC -// | LIMIT 5 -// | )) AS top_values, -// | COUNT(*) - COUNT(status_code) AS Nulls -// | FROM $testTable -// | GROUP BY typeof(status_code) -// | -// | UNION ALL -// | -// | SELECT -// | 'id' AS Field, -// | COUNT(id) AS Count, -// | COUNT(DISTINCT id) AS Distinct, -// | MIN(id) AS Min, -// | MAX(id) AS Max, -// | AVG(CAST(id AS DOUBLE)) AS Avg, -// | typeof(id) AS Type, -// | (SELECT COLLECT_LIST(STRUCT(id, count_id)) -// | FROM ( -// | SELECT id, COUNT(*) AS count_id -// | FROM $testTable -// | GROUP BY id -// | ORDER BY count_id DESC -// | LIMIT 5 
-// | )) AS top_values, -// | COUNT(*) - COUNT(id) AS Nulls -// | FROM $testTable -// | GROUP BY typeof(id) -// |""".stripMargin) + /** + * // val frame = sql(s""" // | SELECT // | 'status_code' AS Field, // | COUNT(status_code) AS + * Count, // | COUNT(DISTINCT status_code) AS Distinct, // | MIN(status_code) AS Min, // | + * MAX(status_code) AS Max, // | AVG(CAST(status_code AS DOUBLE)) AS Avg, // | + * typeof(status_code) AS Type, // | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) // + * \| FROM ( // | SELECT status_code, COUNT(*) AS count_status // | FROM $testTable // | GROUP + * BY status_code // | ORDER BY count_status DESC // | LIMIT 5 // | )) AS top_values, // | + * COUNT(*) - COUNT(status_code) AS Nulls // | FROM $testTable // | GROUP BY typeof(status_code) + * // | // | UNION ALL // | // | SELECT // | 'id' AS Field, // | COUNT(id) AS Count, // | + * COUNT(DISTINCT id) AS Distinct, // | MIN(id) AS Min, // | MAX(id) AS Max, // | AVG(CAST(id AS + * DOUBLE)) AS Avg, // | typeof(id) AS Type, // | (SELECT COLLECT_LIST(STRUCT(id, count_id)) // + * \| FROM ( // | SELECT id, COUNT(*) AS count_id // | FROM $testTable // | GROUP BY id // | + * ORDER BY count_id DESC // | LIMIT 5 // | )) AS top_values, // | COUNT(*) - COUNT(id) AS Nulls + * // | FROM $testTable // | GROUP BY typeof(id) // |""".stripMargin) // Aggregate with + * functions applied to status_code + */ + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=true + | """.stripMargin) val results: Array[Row] = frame.collect() - // Print each row in a readable format + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, "int"), + Row("status_code", 4L, 3L, "200", "403", 276.0, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, "string")) + + implicit val rowOrdering: Ordering[Row] = 
Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical - // scalastyle:off println - results.foreach(row => println(row.mkString(", "))) - println(logicalPlan) - // scalastyle:on println -// val expectedPlan = ? -// comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), 
isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index 93e2121d3..cecd04b2d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -28,10 +28,6 @@ * @return */ public interface AggregatorTranslator { - - static String aggregationAlias(BuiltinFunctionName functionName, QualifiedName name) { - return functionName.name()+"("+name.toString()+")"; - } static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index fdaa04934..579919fa3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -32,7 +32,6 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; -import static org.opensearch.sql.ppl.utils.AggregatorTranslator.aggregationAlias; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -106,7 +105,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the count(field) as Count UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); Alias countAlias = Alias$.MODULE$.apply(count, - aggregationAlias(COUNT, field.getField()), + COUNT.name(), 
NamedExpression.newExprId(), seq(), empty(), @@ -115,7 +114,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the count(DISTINCT field) as CountDistinct UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, - aggregationAlias(COUNT_DISTINCT, field.getField()), + COUNT_DISTINCT.name(), NamedExpression.newExprId(), seq(), empty(), @@ -124,7 +123,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the MAX(field) as MAX UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); Alias maxAlias = Alias$.MODULE$.apply(max, - aggregationAlias(MAX, field.getField()), + MAX.name(), NamedExpression.newExprId(), seq(), empty(), @@ -133,7 +132,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the MIN(field) as Min UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); Alias minAlias = Alias$.MODULE$.apply(min, - aggregationAlias(MIN, field.getField()), + MIN.name(), NamedExpression.newExprId(), seq(), empty(), @@ -142,7 +141,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the AVG(field) as Avg UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); Alias avgAlias = Alias$.MODULE$.apply(avg, - aggregationAlias(AVG, field.getField()), + AVG.name(), NamedExpression.newExprId(), seq(), empty(), @@ -196,7 +195,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Alias for the typeOf(field) as Type UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, - aggregationAlias(TYPEOF, 
field.getField()), + TYPEOF.name(), NamedExpression.newExprId(), seq(), empty(), diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index fa39a4fd4..5d376e18b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -40,28 +40,28 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite Alias(Literal("status_code"), "Field")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "COUNT(status_code)")(), + "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT(status_code)")(), + "COUNT_DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MIN(status_code)")(), + "MIN")(), Alias( UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MAX(status_code)")(), + "MAX")(), Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "AVG(status_code)")(), + "AVG")(), Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF(status_code)")()) + "TYPEOF")()) // Define the aggregate plan with alias for TYPEOF in the aggregation val aggregatePlan = Aggregate( groupingExpressions = Seq(Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF(status_code)")()), + "TYPEOF")()), aggregateExpressions, table) val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) @@ -88,27 +88,27 @@ class 
PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite Seq( Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), - "TYPEOF(id)")()), + "TYPEOF")()), Seq( Alias(Literal("id"), "Field")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), - "COUNT(id)")(), + "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT(id)")(), + "COUNT_DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), - "MIN(id)")(), + "MIN")(), Alias( UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), - "MAX(id)")(), + "MAX")(), Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), - "AVG(id)")(), + "AVG")(), Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), - "TYPEOF(id)")()), + "TYPEOF")()), table) val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) @@ -117,7 +117,7 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite val aggregateStatusCodePlan = Aggregate( Seq(Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF(status_code)")()), + "TYPEOF")()), Seq( Alias(Literal("status_code"), "Field")(), Alias( @@ -125,25 +125,25 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite "COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "COUNT(status_code)")(), + "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT(status_code)")(), + "COUNT_DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MIN(status_code)")(), + "MIN")(), Alias( UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MAX(status_code)")(), + "MAX")(), Alias( UnresolvedFunction("AVG", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "AVG(status_code)")(), + "AVG")(), Alias( UnresolvedFunction( "TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF(status_code)")()), + "TYPEOF")()), table) val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) @@ -155,7 +155,7 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite "TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "TYPEOF(request_path)")()), + "TYPEOF")()), Seq( Alias(Literal("request_path"), "Field")(), Alias( @@ -163,28 +163,28 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "COUNT(request_path)")(), + "COUNT")(), Alias( UnresolvedFunction( "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), - "COUNT_DISTINCT(request_path)")(), + "COUNT_DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "MIN(request_path)")(), + "MIN")(), Alias( UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "MAX(request_path)")(), + "MAX")(), Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "AVG(request_path)")(), + "AVG")(), Alias( UnresolvedFunction( "TYPEOF", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "TYPEOF(request_path)")()), + "TYPEOF")()), table) val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) From ea5cbdf14f6388c8371ba95f9df36b95b5ccf710 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 16 Oct 2024 12:50:34 -0700 Subject: [PATCH 08/12] update query with where clause predicate Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 69 ++++++++++++++----- ...eldSummaryCommandTranslatorTestSuite.scala | 53 +++++++++++++- 2 files changed, 103 insertions(+), 19 deletions(-) diff --git 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index 1a6a1006d..50f325000 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -8,7 +8,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -85,22 +85,57 @@ class FlintSparkPPLFieldSummaryITSuite comparePlans(expectedPlan, logicalPlan, false) } - /** - * // val frame = sql(s""" // | SELECT // | 'status_code' AS Field, // | COUNT(status_code) AS - * Count, // | COUNT(DISTINCT status_code) AS Distinct, // | MIN(status_code) AS Min, // | - * MAX(status_code) AS Max, // | AVG(CAST(status_code AS DOUBLE)) AS Avg, // | - * typeof(status_code) AS Type, // | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) // - * \| FROM ( // | SELECT status_code, COUNT(*) AS count_status // | FROM $testTable // | GROUP - * BY status_code // | ORDER BY count_status DESC // | LIMIT 5 // | )) AS top_values, // | - * COUNT(*) - COUNT(status_code) AS Nulls // | FROM $testTable // | GROUP BY typeof(status_code) - * // | // | UNION ALL // | // | SELECT // | 'id' AS Field, // | COUNT(id) AS Count, // | - * COUNT(DISTINCT id) AS Distinct, // | MIN(id) AS Min, // | 
MAX(id) AS Max, // | AVG(CAST(id AS - * DOUBLE)) AS Avg, // | typeof(id) AS Type, // | (SELECT COLLECT_LIST(STRUCT(id, count_id)) // - * \| FROM ( // | SELECT id, COUNT(*) AS count_id // | FROM $testTable // | GROUP BY id // | - * ORDER BY count_id DESC // | LIMIT 5 // | )) AS top_values, // | COUNT(*) - COUNT(id) AS Nulls - * // | FROM $testTable // | GROUP BY typeof(id) // |""".stripMargin) // Aggregate with - * functions applied to status_code - */ + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) 
+ + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + test( "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { val frame = sql(s""" diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index 5d376e18b..ed0f078c0 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -12,9 +12,9 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Filter, Project, Union} class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite @@ -69,6 +69,55 @@ class 
PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test( "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { val context = new CatalystPlanContext From 7a672fc58c21ce7b03656421a7e99763f1941b4d Mon 
Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 17 Oct 2024 14:44:01 -0700 Subject: [PATCH 09/12] update command and remove the topvalues Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 353 ++++++++++- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 - .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 - .../opensearch/sql/ast/tree/FieldSummary.java | 9 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 21 - .../sql/ppl/parser/AstExpressionBuilder.java | 8 - .../sql/ppl/utils/DataTypeTransformer.java | 6 +- .../ppl/utils/FieldSummaryTransformer.java | 228 ++++--- ...eldSummaryCommandTranslatorTestSuite.scala | 244 -------- ...lPlanFieldSummaryTranslatorTestSuite.scala | 570 ++++++++++++++++++ 10 files changed, 1057 insertions(+), 384 deletions(-) delete mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index 50f325000..8153f18c7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -8,7 +8,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, 
Expression, Literal, NamedExpression, Not, SortOrder, Subtract} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -42,7 +42,71 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 4, 3, 200, 403, 276.0, "int")) + Array(Row("status_code", 4, 3, 200, 403, 184.0, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + 
groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=false ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 276.0, 2, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -66,6 +130,14 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF")()) @@ -92,7 +164,73 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 2, 2, 301, 403, 352.0, "int")) + Array(Row("status_code", 2, 2, 301, 403, 352.0, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 0, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) 
assert(results.sorted.sameElements(expectedResults.sorted)) @@ -116,6 +254,14 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF")()) @@ -145,9 +291,183 @@ class FlintSparkPPLFieldSummaryITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("id", 6L, 6L, "1", "6", 3.5, "int"), - Row("status_code", 4L, 3L, "200", "403", 276.0, "int"), - Row("request_path", 4L, 3L, "/about", "/home", null, "string")) + Row("id", 6L, 6L, "1", "6", 3.5, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 184.0, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", 0.0, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = 
false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), 
aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=false + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 0, "int"), + Row("status_code", 4L, 3L, "200", 
"403", 276.0, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, 2, "string")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -179,6 +499,11 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), Alias( UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), "TYPEOF")()), @@ -211,6 +536,14 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), Alias( UnresolvedFunction( "TYPEOF", @@ -253,6 +586,14 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), Alias( UnresolvedFunction( "TYPEOF", diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 83bdb185d..aeffffe16 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -90,7 +90,6 @@ NUM: 'NUM'; // FIELDSUMMARY keywords FIELDSUMMARY: 'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; -TOPVALUES: 'TOPVALUES'; NULLS: 'NULLS'; // ARGUMENT KEYWORDS diff --git 
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index a29c68f87..3afafcebc 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -67,7 +67,6 @@ fieldsummaryCommand fieldsummaryParameter : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields - | TOPVALUES EQUAL integerLiteral # fieldsummaryTopValues | NULLS EQUAL booleanLiteral # fieldsummaryNulls ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index 1d3b9ffed..774a08531 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -20,15 +20,13 @@ import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; @Getter @ToString @EqualsAndHashCode(callSuper = false) public class FieldSummary extends UnresolvedPlan { private List includeFields; - private int topValues; - private boolean ignoreNull; + private boolean includeNull; private List collect; private UnresolvedPlan child; @@ -38,10 +36,7 @@ public FieldSummary(List collect) { .forEach(exp -> { switch (((NamedExpression) exp).getExpressionId()) { case NULLS: - this.ignoreNull = (boolean) ((Literal) exp.getChild().get(0)).getValue(); - break; - case TOPVALUES: - this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue(); + this.includeNull = (boolean) ((Literal) exp.getChild().get(0)).getValue(); break; case INCLUDEFIELDS: this.includeFields = ((FieldList) exp.getChild().get(0)).getFieldList(); diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 8482f4be2..76a7a0c79 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -383,27 +383,6 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { } @Override - /** - * 'Union false, false - * :- 'Aggregate ['typeof('status_code)], [status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30] - * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, 'status_code, count_status, 'count_status)), None)] - * : : +- 'SubqueryAlias __auto_generated_subquery_name - * : : +- 'GlobalLimit 5 - * : : +- 'LocalLimit 5 - * : : +- 'Sort ['count_status DESC NULLS LAST], true - * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] - * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false - * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false - * +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41] - * : +- 'Project [unresolvedalias('COLLECT_LIST(struct(id, 'id, count_id, 'count_id)), None)] - * : +- 'SubqueryAlias __auto_generated_subquery_name - * : +- 'GlobalLimit 5 - * : +- 'LocalLimit 5 - * : +- 'Sort ['count_id DESC NULLS LAST], true - * : 
+- 'Aggregate ['id], ['id, 'COUNT(1) AS count_id#38] - * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false - */ public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getChild().get(0).accept(this, context); return FieldSummaryTransformer.translate(fieldSummary, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 47220174f..28f6cef9c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -55,7 +55,6 @@ import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.TOPVALUES; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; @@ -194,13 +193,6 @@ public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.F return new NamedExpression(INCLUDEFIELDS,new FieldList(includeFields)); } - - @Override - public UnresolvedExpression visitFieldsummaryTopValues(OpenSearchPPLParser.FieldsummaryTopValuesContext ctx) { - return new NamedExpression(TOPVALUES,visitIntegerLiteral(ctx.integerLiteral())); - } - - @Override public UnresolvedExpression visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { return new NamedExpression(NULLS,visitBooleanLiteral(ctx.booleanLiteral())); diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 4345b0897..b38148496 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -20,7 +20,10 @@ import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; +import java.util.Arrays; import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -39,7 +42,8 @@ */ public interface DataTypeTransformer { static Seq seq(T... elements) { - return seq(List.of(elements)); + return seq(Arrays.stream(elements).filter(Objects::nonNull) + .collect(Collectors.toList())); } static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 579919fa3..494cb2265 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -5,20 +5,33 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.AliasIdentifier; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.Alias; import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import 
org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.expressions.Subtract; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LocalLimit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.Sort; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; +import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ppl.CatalystPlanContext; -import scala.Option; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.function.Function; @@ -43,50 +56,13 @@ public interface FieldSummaryTransformer { /** * translate the field summary into the following query: + * source = spark_catalog.default.flint_ppl_test | fieldsummary includefields= id, status_code nulls=true * ----------------------------------------------------- - * // for each column create statement: - * SELECT - * 'columnA' AS Field, - * COUNT(columnA) AS Count, - * COUNT(DISTINCT columnA) AS Distinct, - * MIN(columnA) AS Min, - * MAX(columnA) AS Max, - * AVG(CAST(columnA AS DOUBLE)) AS Avg, - * typeof(columnA) AS Type, - * (SELECT COLLECT_LIST(STRUCT(columnA, count_status)) - * FROM ( - * SELECT columnA, COUNT(*) AS count_status - * FROM $testTable - * GROUP BY columnA - * ORDER BY count_status DESC - * LIMIT 5 - * )) AS top_values, - * 
COUNT(*) - COUNT(columnA) AS Nulls - * FROM $testTable - * GROUP BY typeof(columnA) - * - * // union all queries - * UNION ALL - * - * SELECT - * 'columnB' AS Field, - * COUNT(columnB) AS Count, - * COUNT(DISTINCT columnB) AS Distinct, - * MIN(columnB) AS Min, - * MAX(columnB) AS Max, - * AVG(CAST(columnB AS DOUBLE)) AS Avg, - * typeof(columnB) AS Type, - * (SELECT COLLECT_LIST(STRUCT(columnB, count_columnB)) - * FROM ( - * SELECT column-, COUNT(*) AS count_column- - * FROM $testTable - * GROUP BY columnB - * ORDER BY count_column- DESC - * LIMIT 5 - * )) AS top_values, - * COUNT(*) - COUNT(columnB) AS Nulls - * FROM $testTable - * GROUP BY typeof(columnB) + * 'Union false, false + * :- 'Aggregate ['typeof('status_code)], [status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30] + * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41] + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { List> aggBranches = fieldSummary.getIncludeFields().stream().map(field -> { @@ -139,58 +115,18 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont seq()); //Alias for the AVG(field) as Avg - UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); - Alias avgAlias = Alias$.MODULE$.apply(avg, - 
AVG.name(), + Alias avgAlias = getAvgAlias(fieldSummary, fieldLiteral); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); + Alias nonNullAlias = Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, NamedExpression.newExprId(), seq(), empty(), seq()); - if (fieldSummary.getTopValues() > 0) { - // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values - CreateNamedStruct structExpr = new CreateNamedStruct(seq( - fieldLiteral, - count - )); - UnresolvedFunction collectList = new UnresolvedFunction( - seq("COLLECT_LIST"), - seq(structExpr), - false, - empty(), - !fieldSummary.isIgnoreNull() - ); - context.getNamedParseExpressions().push( - Alias$.MODULE$.apply( - collectList, - TOP_VALUES, - NamedExpression.newExprId(), - seq(), - empty(), - seq() - )); - } - - if (!fieldSummary.isIgnoreNull()) { - // Alias COUNT(*) - COUNT(column2) AS Nulls - UnresolvedFunction countStar = new UnresolvedFunction( - seq(COUNT.name()), - seq(Literal.create(1, IntegerType)), - false, - empty(), - false - ); - - context.getNamedParseExpressions().push( - Alias$.MODULE$.apply( - new Subtract(countStar, count), - NULLS, - NamedExpression.newExprId(), - seq(), - empty(), - seq() - )); - } //Alias for the typeOf(field) as Type UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); @@ -202,10 +138,112 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont seq()); //Aggregation - return (Function) p -> new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, typeOfAlias), p); - }).collect(Collectors.toList()); + return (Function) p -> + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, nonNullAlias, typeOfAlias), p); + 
}).collect(Collectors.toList()); + + return context.applyBranches(aggBranches); + } - LogicalPlan plan = context.applyBranches(aggBranches); - return plan; + /** + * Alias for Avg (if isIncludeNull use COALESCE to replace nulls with zeros) + */ + private static Alias getAvgAlias(FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { + UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + AVG.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + if (fieldSummary.isIncludeNull()) { + UnresolvedFunction coalesceExpr = new UnresolvedFunction( + seq("COALESCE"), + seq(fieldLiteral, Literal.create(0, DataTypes.IntegerType)), + false, + empty(), + false + ); + avg = new UnresolvedFunction(seq(AVG.name()), seq(coalesceExpr), false, empty(), false); + avgAlias = Alias$.MODULE$.apply(avg, + AVG.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + } + return avgAlias; + } + + /** + * top values sub-query + */ + private static Alias topValuesSubQueryAlias(FieldSummary fieldSummary, CatalystPlanContext context, UnresolvedAttribute fieldLiteral, UnresolvedFunction count) { + int topValues = 5;// this value should come from the FieldSummary definition + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIncludeNull() + ); + Alias topValuesAlias = Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + Project subQueryProject = new Project(seq(topValuesAlias), buildTopValueSubQuery(topValues, fieldLiteral, context)); + ScalarSubquery scalarSubquery = ScalarSubquery$.MODULE$.apply( + subQueryProject, + seq(new ArrayList()), + 
NamedExpression.newExprId(), + seq(new ArrayList()), + empty(), + empty()); + + return Alias$.MODULE$.apply( + scalarSubquery, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + } + + /** + * inner top values query + * ----------------------------------------------------- + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, count_status)), None)] + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + private static LogicalPlan buildTopValueSubQuery(int topValues,UnresolvedAttribute fieldLiteral, CatalystPlanContext context ) { + //Alias for the count(field) as Count + UnresolvedFunction countFunc = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(countFunc, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + Aggregate aggregate = new Aggregate(seq(fieldLiteral), seq(countAlias), context.getPlan()); + Project project = new Project(seq(fieldLiteral, countAlias), aggregate); + SortOrder sortOrder = new SortOrder(countAlias, Descending$.MODULE$, Ascending$.MODULE$.defaultNullOrdering(), seq()); + Sort sort = new Sort(seq(sortOrder), true, project); + GlobalLimit limit = new GlobalLimit(Literal.create(topValues, IntegerType), new LocalLimit(Literal.create(topValues, IntegerType), sort)); + return new SubqueryAlias(new AliasIdentifier(TOP_VALUES+"_subquery"), limit); } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala deleted file mode 100644 index ed0f078c0..000000000 --- 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.spark.ppl - -import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.scalatest.matchers.should.Matchers - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Filter, Project, Union} - -class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite - extends SparkFunSuite - with PlanTest - with LogicalPlanTestUtils - with Matchers { - - private val planTransformer = new CatalystQueryPlanVisitor() - private val pplParser = new PPLSyntaxParser() - - test("test fieldsummary with single field includefields(status_code) & nulls=true") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit( - plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=true"), - context) - - // Define the table - val table = UnresolvedRelation(Seq("t")) - - // Aggregate with functions applied to status_code - val aggregateExpressions: Seq[NamedExpression] = Seq( - Alias(Literal("status_code"), "Field")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "COUNT")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), - Alias( - UnresolvedFunction("MIN", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MIN")(), - Alias( - UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MAX")(), - Alias( - UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "AVG")(), - Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF")()) - - // Define the aggregate plan with alias for TYPEOF in the aggregation - val aggregatePlan = Aggregate( - groupingExpressions = Seq(Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF")()), - aggregateExpressions, - table) - val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) - // Compare the two plans - comparePlans(expectedPlan, logPlan, false) - } - - test( - "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit( - plan( - pplParser, - "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"), - context) - - // Define the table - val table = UnresolvedRelation(Seq("t")) - - // Aggregate with functions applied to status_code - val aggregateExpressions: Seq[NamedExpression] = Seq( - Alias(Literal("status_code"), "Field")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "COUNT")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), - Alias( - UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MIN")(), - Alias( - UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MAX")(), - Alias( - UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "AVG")(), - Alias( - 
UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF")()) - - val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) - val aggregatePlan = Aggregate( - groupingExpressions = Seq(Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF")()), - aggregateExpressions, - Filter(filterCondition, table)) - - val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) - - // Compare the two plans - comparePlans(expectedPlan, logPlan, false) - } - - test( - "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { - val context = new CatalystPlanContext - val logPlan = - planTransformer.visit( - plan( - pplParser, - "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), - context) - - // Define the table - val table = UnresolvedRelation(Seq("t")) - - // Aggregate with functions applied to status_code - // Define the aggregate plan with alias for TYPEOF in the aggregation - val aggregateIdPlan = Aggregate( - Seq( - Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), - "TYPEOF")()), - Seq( - Alias(Literal("id"), "Field")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), - "COUNT")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT")(), - Alias( - UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), - "MIN")(), - Alias( - UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), - "MAX")(), - Alias( - UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), - "AVG")(), - Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), - "TYPEOF")()), - table) - val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) - - // 
Aggregate with functions applied to status_code - // Define the aggregate plan with alias for TYPEOF in the aggregation - val aggregateStatusCodePlan = Aggregate( - Seq(Alias( - UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "TYPEOF")()), - Seq( - Alias(Literal("status_code"), "Field")(), - Alias( - UnresolvedFunction( - "COUNT", - Seq(UnresolvedAttribute("status_code")), - isDistinct = false), - "COUNT")(), - Alias( - UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), - Alias( - UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MIN")(), - Alias( - UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "MAX")(), - Alias( - UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), - "AVG")(), - Alias( - UnresolvedFunction( - "TYPEOF", - Seq(UnresolvedAttribute("status_code")), - isDistinct = false), - "TYPEOF")()), - table) - val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) - - // Define the aggregate plan with alias for TYPEOF in the aggregation - val aggregatePlan = Aggregate( - Seq( - Alias( - UnresolvedFunction( - "TYPEOF", - Seq(UnresolvedAttribute("request_path")), - isDistinct = false), - "TYPEOF")()), - Seq( - Alias(Literal("request_path"), "Field")(), - Alias( - UnresolvedFunction( - "COUNT", - Seq(UnresolvedAttribute("request_path")), - isDistinct = false), - "COUNT")(), - Alias( - UnresolvedFunction( - "COUNT", - Seq(UnresolvedAttribute("request_path")), - isDistinct = true), - "COUNT_DISTINCT")(), - Alias( - UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "MIN")(), - Alias( - UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), - "MAX")(), - Alias( - UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = 
false), - "AVG")(), - Alias( - UnresolvedFunction( - "TYPEOF", - Seq(UnresolvedAttribute("request_path")), - isDistinct = false), - "TYPEOF")()), - table) - val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) - - val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) - // Compare the two plans - comparePlans(expectedPlan, logPlan, false) - } -} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala new file mode 100644 index 000000000..3580b9a8d --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -0,0 +1,570 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not, Subtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project, Union} + +class PPLLogicalPlanFieldSummaryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fieldsummary with single field includefields(status_code) & nulls=false") { + val 
context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=true"), + context) + + // Define 
the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = 
UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = 
UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the 
aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + 
val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = 
Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "COUNT_DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", 
Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } +} From 92369996dbd724e85ca05da22a7a1a94ebd8e6eb Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 17 Oct 2024 20:14:44 -0700 Subject: [PATCH 10/12] update command docs Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 6 + docs/ppl-lang/ppl-fieldsummary-command.md | 83 +++++++++ .../FlintSparkPPLFieldSummaryITSuite.scala | 159 ++++++++++++++++-- .../function/BuiltinFunctionName.java | 2 + .../sql/ppl/utils/AggregatorTranslator.java | 4 + .../ppl/utils/FieldSummaryTransformer.java | 34 ++-- ...lPlanFieldSummaryTranslatorTestSuite.scala | 139 +++++++++++++++ 7 files changed, 401 insertions(+), 26 deletions(-) create mode 100644 docs/ppl-lang/ppl-fieldsummary-command.md diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 96eeef726..d7f501dce 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -28,6 +28,12 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | eval b1 = b + 1 | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) - `source = table | eval b1 = lower(b) | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) +**Field-Summary** +[See additional command details](ppl-fieldsummary-command.md) +- `source = t | fieldsummary 
includefields=status_code nulls=false` +- `source = t | fieldsummary includefields= id, status_code, request_path nulls=true` +- `source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true` + **Nested-Fields** - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1` - `source = catalog.table | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield` diff --git a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md new file mode 100644 index 000000000..3cf1348e2 --- /dev/null +++ b/docs/ppl-lang/ppl-fieldsummary-command.md @@ -0,0 +1,83 @@ +## PPL `fieldsummary` command + +**Description** +Using `fieldsummary` command to: + - Calculate basic statistics for each field (count, distinct count, min, max, avg, stddev, mean) + - Determine the data type of each field + +**Syntax** + +`... | fieldsummary includefields= <field-list> (nulls=true/false)` + +* command accepts any preceding pipe before the terminal `fieldsummary` command and will take them into account. 
+* `includefields`: list of all the columns to be collected with statistics into a unified result set +* `nulls`: optional; if true, include the null values in the aggregation calculations (replace null with zero for numeric values) + +### Example 1: + +PPL query: + + os> source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | Field | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | Nulls | TYPEOF | + |------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "status_code" | 2 | 2 | 301 | 403 | 352.0 | 352.0 | 72.12489168102785 | 0 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + +### Example 2: + +PPL query: + + os> source = t | fieldsummary includefields= id, status_code, request_path nulls=true + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | Field | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | Nulls | TYPEOF | + |------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "id" | 6 | 6 | 1 | 6 | 3.5 | 3.5 | 1.8708286933869707 | 0 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "status_code" | 4 | 3 | 200 | 403 | 184.0 | 184.0 | 161.16699413961905 | 2 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "request_path" | 2 | 2 | /about| /home | 0.0 | 0.0 | 0 | 2 |"string"| 
+------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + +### Additional Info +The actual query is translated into the following SQL-like statement: + +```sql + SELECT + id AS Field, + COUNT(id) AS COUNT, + COUNT(DISTINCT id) AS COUNT_DISTINCT, + MIN(id) AS MIN, + MAX(id) AS MAX, + AVG(id) AS AVG, + MEAN(id) AS MEAN, + STDDEV(id) AS STDDEV, + (COUNT(1) - COUNT(id)) AS Nulls, + TYPEOF(id) AS TYPEOF + FROM + t + GROUP BY + TYPEOF(id), id; +UNION + SELECT + status_code AS Field, + COUNT(status_code) AS COUNT, + COUNT(DISTINCT status_code) AS COUNT_DISTINCT, + MIN(status_code) AS MIN, + MAX(status_code) AS MAX, + AVG(status_code) AS AVG, + MEAN(status_code) AS MEAN, + STDDEV(status_code) AS STDDEV, + (COUNT(1) - COUNT(status_code)) AS Nulls, + TYPEOF(status_code) AS TYPEOF + FROM + t + GROUP BY + TYPEOF(status_code), status_code; +``` +For each such column (id, status_code) there will be a unique statement and all the fields will be presented together in the result using a UNION operator + + +### Limitation: + - `topvalues` option was removed from this command due to the possible performance impact of such sub-query. As an alternative one can use the `top` command directly as shown [here](ppl-top-command.md). 
+ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index 8153f18c7..44997fa94 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -42,7 +42,7 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 4, 3, 200, 403, 184.0, 2, "int")) + Array(Row("status_code", 4, 3, 200, 403, 184.0, 184.0, 161.16699413961905, 2, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -73,6 +73,26 @@ class FlintSparkPPLFieldSummaryITSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -106,7 +126,7 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 4, 3, 200, 403, 276.0, 2, "int")) + Array(Row("status_code", 4, 3, 200, 403, 276.0, 276.0, 97.1356439899038, 2, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -130,6 +150,12 @@ class 
FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -164,7 +190,7 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 2, 2, 301, 403, 352.0, 0, "int")) + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -195,6 +221,26 @@ class FlintSparkPPLFieldSummaryITSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -230,7 +276,7 @@ class FlintSparkPPLFieldSummaryITSuite | """.stripMargin) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = - Array(Row("status_code", 2, 2, 301, 403, 352.0, 0, "int")) + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -254,6 +300,12 @@ class 
FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -291,9 +343,9 @@ class FlintSparkPPLFieldSummaryITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("id", 6L, 6L, "1", "6", 3.5, 0, "int"), - Row("status_code", 4L, 3L, "200", "403", 184.0, 2, "int"), - Row("request_path", 4L, 3L, "/about", "/home", 0.0, 2, "string")) + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 184.0, 184.0, 161.16699413961905, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", 0.0, 0.0, 0.0, 2, "string")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -332,6 +384,26 @@ class FlintSparkPPLFieldSummaryITSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -376,6 +448,26 @@ class FlintSparkPPLFieldSummaryITSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = 
false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -433,6 +525,26 @@ class FlintSparkPPLFieldSummaryITSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -465,9 +577,9 @@ class FlintSparkPPLFieldSummaryITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( - Row("id", 6L, 6L, "1", "6", 3.5, 0, "int"), - Row("status_code", 4L, 3L, "200", "403", 276.0, 2, "int"), - Row("request_path", 4L, 3L, "/about", "/home", null, 2, "string")) + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 276.0, 276.0, 97.1356439899038, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, null, null, 2, "string")) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -499,6 +611,12 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + 
"STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -536,6 +654,15 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -586,6 +713,18 @@ class FlintSparkPPLFieldSummaryITSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 1f58f92d1..d9ca0c7c2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -158,6 +158,8 @@ public enum BuiltinFunctionName { /** Aggregation Function. 
*/ AVG(FunctionName.of("avg")), + MEAN(FunctionName.of("mean")), + STDDEV(FunctionName.of("stddev")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), COUNT_DISTINCT(FunctionName.of("count_distinct")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index cecd04b2d..a01b38a80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -41,12 +41,16 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); + case MEAN: + return new UnresolvedFunction(seq("MEAN"), seq(arg), distinct, empty(),false); case AVG: return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); + case STDDEV: + return new UnresolvedFunction(seq("STDDEV"), seq(arg), distinct, empty(),false); case STDDEV_POP: return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 494cb2265..7eb4ece39 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.tree.FieldSummary; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystPlanContext; import java.util.ArrayList; @@ -43,7 +44,9 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT_DISTINCT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MEAN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -55,14 +58,7 @@ public interface FieldSummaryTransformer { String FIELD = "Field"; /** - * translate the field summary into the following query: - * source = spark_catalog.default.flint_ppl_test | fieldsummary includefields= id, status_code nulls=true - * ----------------------------------------------------- - * 'Union false, false - * :- 'Aggregate ['typeof('status_code)], [status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30] - * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false - * +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, 
scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41] - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * translate the command into the aggregate statement group by the column name */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { List> aggBranches = fieldSummary.getIncludeFields().stream().map(field -> { @@ -115,7 +111,13 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont seq()); //Alias for the AVG(field) as Avg - Alias avgAlias = getAvgAlias(fieldSummary, fieldLiteral); + Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); + + //Alias for the MEAN(field) as Mean + Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); + + //Alias for the STDDEV(field) as Stddev + Alias stddevAlias = getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); // Alias COUNT(*) - COUNT(column2) AS Nulls UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); @@ -139,19 +141,19 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont //Aggregation return (Function) p -> - new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, nonNullAlias, typeOfAlias), p); + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); }).collect(Collectors.toList()); return context.applyBranches(aggBranches); } /** - * Alias for Avg (if isIncludeNull use COALESCE to replace nulls with zeros) + * Alias for aggregate function (if isIncludeNull use COALESCE to replace nulls with zeros) */ - private static Alias getAvgAlias(FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { - UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, 
empty(), false); + private static Alias getAggMethodAlias(BuiltinFunctionName method, FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { + UnresolvedFunction avg = new UnresolvedFunction(seq(method.name()), seq(fieldLiteral), false, empty(), false); Alias avgAlias = Alias$.MODULE$.apply(avg, - AVG.name(), + method.name(), NamedExpression.newExprId(), seq(), empty(), @@ -165,9 +167,9 @@ private static Alias getAvgAlias(FieldSummary fieldSummary, UnresolvedAttribute empty(), false ); - avg = new UnresolvedFunction(seq(AVG.name()), seq(coalesceExpr), false, empty(), false); + avg = new UnresolvedFunction(seq(method.name()), seq(coalesceExpr), false, empty(), false); avgAlias = Alias$.MODULE$.apply(avg, - AVG.name(), + method.name(), NamedExpression.newExprId(), seq(), empty(), diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala index 3580b9a8d..4edc9bee7 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -53,6 +53,12 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -112,6 +118,26 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + 
Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -174,6 +200,26 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -231,6 +277,12 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -301,6 +353,26 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + 
isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -345,6 +417,26 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -401,6 +493,26 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite isDistinct = false)), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -460,6 +572,12 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -497,6 +615,15 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), 
"AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), @@ -546,6 +673,18 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite Alias( UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), Alias( Subtract( UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), From 132978bc9fe166d38f34dcb35f673da03469496a Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 22 Oct 2024 15:33:06 -0700 Subject: [PATCH 11/12] update with comments feedback Signed-off-by: YANGDB --- .../FlintSparkPPLFieldSummaryITSuite.scala | 20 +-- .../sql/ast/expression/NamedExpression.java | 30 ---- .../opensearch/sql/ast/tree/FieldSummary.java | 28 ++- .../function/BuiltinFunctionName.java | 1 - .../sql/ppl/parser/AstExpressionBuilder.java | 9 +- .../sql/ppl/utils/DataTypeTransformer.java | 4 +- .../ppl/utils/FieldSummaryTransformer.java | 170 +++++++++--------- ...lPlanFieldSummaryTranslatorTestSuite.scala | 20 +-- 8 files changed, 122 insertions(+), 160 deletions(-) delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index 44997fa94..5a5990001 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -56,7 +56,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -140,7 +140,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -204,7 +204,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -290,7 +290,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -367,7 +367,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN")(), @@ -431,7 +431,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -508,7 +508,7 @@ class 
FlintSparkPPLFieldSummaryITSuite "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN")(), @@ -601,7 +601,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN")(), @@ -644,7 +644,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -703,7 +703,7 @@ class FlintSparkPPLFieldSummaryITSuite "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN")(), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java deleted file mode 100644 index 4fee68a09..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/NamedExpression.java +++ /dev/null @@ -1,30 +0,0 @@ -package org.opensearch.sql.ast.expression; - -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.ToString; -import org.opensearch.sql.ast.AbstractNodeVisitor; - -import java.util.Arrays; -import java.util.List; - -@Getter -@ToString -@RequiredArgsConstructor -@EqualsAndHashCode(callSuper = false) -public class NamedExpression extends UnresolvedExpression { - private final int expressionId; - private final UnresolvedExpression 
expression; - - // private final DataType valueType; - @Override - public List getChild() { - return Arrays.asList(expression); - } - - @Override - public R accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visit(this, context); - } -} \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index 774a08531..a8072e76b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -10,38 +10,30 @@ import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; -import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldList; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.NamedExpression; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; - @Getter @ToString @EqualsAndHashCode(callSuper = false) public class FieldSummary extends UnresolvedPlan { - private List includeFields; + private List includeFields; private boolean includeNull; private List collect; private UnresolvedPlan child; public FieldSummary(List collect) { this.collect = collect; - collect.stream().filter(e->e instanceof NamedExpression) - .forEach(exp -> { - switch (((NamedExpression) exp).getExpressionId()) { - case NULLS: - this.includeNull = (boolean) ((Literal) exp.getChild().get(0)).getValue(); - break; - case INCLUDEFIELDS: - this.includeFields = ((FieldList) 
exp.getChild().get(0)).getFieldList(); - break; - } + collect.forEach(exp -> { + if (exp instanceof Argument) { + this.includeNull = (boolean) ((Argument)exp).getValue().getValue(); + } + if (exp instanceof AttributeList) { + this.includeFields = ((AttributeList)exp).getAttrList(); + } }); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 0167ee667..bb0a5aeb5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -168,7 +168,6 @@ public enum BuiltinFunctionName { STDDEV(FunctionName.of("stddev")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), - COUNT_DISTINCT(FunctionName.of("count_distinct")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), // sample variance diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 28f6cef9c..ea51ca7a1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -16,6 +16,7 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.DataType; @@ -30,7 +31,6 @@ import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; 
-import org.opensearch.sql.ast.expression.NamedExpression; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -186,16 +186,15 @@ public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext @Override public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.FieldsummaryIncludeFieldsContext ctx) { - List includeFields = ctx.fieldList().fieldExpression().stream() + List list = ctx.fieldList().fieldExpression().stream() .map(this::visitFieldExpression) - .map(p->(Field)p) .collect(Collectors.toList()); - return new NamedExpression(INCLUDEFIELDS,new FieldList(includeFields)); + return new AttributeList(list); } @Override public UnresolvedExpression visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { - return new NamedExpression(NULLS,visitBooleanLiteral(ctx.booleanLiteral())); + return new Argument("NULLS",(Literal)visitBooleanLiteral(ctx.booleanLiteral())); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index b38148496..62eef90ed 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -42,9 +42,9 @@ */ public interface DataTypeTransformer { static Seq seq(T... 
elements) { - return seq(Arrays.stream(elements).filter(Objects::nonNull) - .collect(Collectors.toList())); + return seq(List.of(elements)); } + static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 7eb4ece39..dd8f01874 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Sort; import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystPlanContext; @@ -42,7 +43,6 @@ import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT_DISTINCT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MEAN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; @@ -58,91 +58,93 @@ public interface FieldSummaryTransformer { String FIELD = "Field"; /** - * translate the command into the aggregate statement group by the column name + * translate the command into the aggregate statement group by the column name */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { - List> aggBranches = 
fieldSummary.getIncludeFields().stream().map(field -> { - Literal fieldNameLiteral = Literal.create(field.getField().toString(), StringType); - UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(field.getField().getParts())); + List> aggBranches = fieldSummary.getIncludeFields().stream() + .filter(field -> field instanceof org.opensearch.sql.ast.expression.Field ) + .map(field -> { + Literal fieldNameLiteral = Literal.create(((org.opensearch.sql.ast.expression.Field)field).getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(((org.opensearch.sql.ast.expression.Field)field).getField().getParts())); context.withProjectedFields(Collections.singletonList(field)); - // Alias for the field name as Field - Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, - FIELD, - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Alias for the count(field) as Count - UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); - Alias countAlias = Alias$.MODULE$.apply(count, - COUNT.name(), - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Alias for the count(DISTINCT field) as CountDistinct - UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); - Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, - COUNT_DISTINCT.name(), - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Alias for the MAX(field) as MAX - UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); - Alias maxAlias = Alias$.MODULE$.apply(max, - MAX.name(), - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Alias for the MIN(field) as Min - UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); - Alias minAlias = Alias$.MODULE$.apply(min, - MIN.name(), - 
NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Alias for the AVG(field) as Avg - Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); - - //Alias for the MEAN(field) as Mean - Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); - - //Alias for the STDDEV(field) as Stddev - Alias stddevAlias = getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); - - // Alias COUNT(*) - COUNT(column2) AS Nulls - UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); - Alias nonNullAlias = Alias$.MODULE$.apply( - new Subtract(countStar, count), - NULLS, - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - - //Alias for the typeOf(field) as Type - UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); - Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, - TYPEOF.name(), - NamedExpression.newExprId(), - seq(), - empty(), - seq()); - - //Aggregation - return (Function) p -> - new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); - }).collect(Collectors.toList()); + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(field) as Count + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(DISTINCT field) as CountDistinct + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + "DISTINCT", + 
NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MAX(field) as MAX + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + MAX.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + MIN.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the AVG(field) as Avg + Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); + + //Alias for the MEAN(field) as Mean + Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); + + //Alias for the STDDEV(field) as Stddev + Alias stddevAlias = getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); + Alias nonNullAlias = Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + TYPEOF.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + return (Function) p -> + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); + }).collect(Collectors.toList()); return context.applyBranches(aggBranches); } @@ -232,7 +234,7 @@ private static Alias topValuesSubQueryAlias(FieldSummary fieldSummary, CatalystP * : : +- 'Aggregate ['status_code], 
['status_code, 'COUNT(1) AS count_status#27] * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false */ - private static LogicalPlan buildTopValueSubQuery(int topValues,UnresolvedAttribute fieldLiteral, CatalystPlanContext context ) { + private static LogicalPlan buildTopValueSubQuery(int topValues, UnresolvedAttribute fieldLiteral, CatalystPlanContext context) { //Alias for the count(field) as Count UnresolvedFunction countFunc = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); Alias countAlias = Alias$.MODULE$.apply(countFunc, @@ -246,6 +248,6 @@ private static LogicalPlan buildTopValueSubQuery(int topValues,UnresolvedAttribu SortOrder sortOrder = new SortOrder(countAlias, Descending$.MODULE$, Ascending$.MODULE$.defaultNullOrdering(), seq()); Sort sort = new Sort(seq(sortOrder), true, project); GlobalLimit limit = new GlobalLimit(Literal.create(topValues, IntegerType), new LocalLimit(Literal.create(topValues, IntegerType), sort)); - return new SubqueryAlias(new AliasIdentifier(TOP_VALUES+"_subquery"), limit); + return new SubqueryAlias(new AliasIdentifier(TOP_VALUES + "_subquery"), limit); } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala index 4edc9bee7..c14e1f6cf 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -43,7 +43,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -101,7 +101,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -183,7 +183,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -267,7 +267,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -336,7 +336,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN")(), @@ -400,7 +400,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -476,7 +476,7 @@ class PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN")(), @@ -562,7 +562,7 @@ class 
PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), "MIN")(), @@ -605,7 +605,7 @@ PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT")(), Alias( UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN")(), @@ -663,7 +663,7 @@ PPLLogicalPlanFieldSummaryTranslatorTestSuite "COUNT", Seq(UnresolvedAttribute("request_path")), isDistinct = true), - "COUNT_DISTINCT")(), + "DISTINCT")(), Alias( UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), "MIN")(), From e3005d46d816f5a1933b056396942c22d7fa15e9 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 23 Oct 2024 18:14:08 -0700 Subject: [PATCH 12/12] update `FIELD SUMMARY` symbols to the keywordsCanBeId bag of words Signed-off-by: YANGDB --- docs/ppl-lang/ppl-fieldsummary-command.md | 4 ++-- ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md index 3cf1348e2..468c2046b 100644 --- a/docs/ppl-lang/ppl-fieldsummary-command.md +++ b/docs/ppl-lang/ppl-fieldsummary-command.md @@ -19,7 +19,7 @@ PPL query: os> source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| - | Fiels | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | NUlls | TYPEOF | + | Fields | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | Nulls | TYPEOF | 
|------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| | "status_code" | 2 | 2 | 301 | 403 | 352.0 | 352.0 | 72.12489168102785 | 0 | "int" | +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| @@ -30,7 +30,7 @@ PPL query: os> source = t | fieldsummary includefields= id, status_code, request_path nulls=true +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| - | Fiels | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | NUlls | TYPEOF | + | Fields | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | Nulls | TYPEOF | |------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| | "id" | 6 | 6 | 1 | 6 | 3.5 | 3.5 | 1.8708286933869707 | 0 | "int" | +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 7ca53a652..742ae2a30 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -1071,6 +1071,10 @@ keywordsCanBeId | SPARKLINE | C | DC + // FIELD SUMMARY + | FIELDSUMMARY + | INCLUDEFIELDS + | NULLS // JOIN TYPE | OUTER | INNER