diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala index fc2e8caac..db00adae5 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -36,62 +36,79 @@ class FlintSparkPPLFieldSummaryITSuite } } - test("test fillnull with one null replacement value and one column") { -// val frame = sql(s""" -// | source = $testTable | fieldsummary includefields= status_code, id, response_time topvalues=5 nulls=true -// | """.stripMargin) - + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { val frame = sql(s""" - | SELECT - | 'status_code' AS Field, - | COUNT(status_code) AS Count, - | COUNT(DISTINCT status_code) AS Distinct, - | MIN(status_code) AS Min, - | MAX(status_code) AS Max, - | AVG(CAST(status_code AS DOUBLE)) AS Avg, - | typeof(status_code) AS Type, - | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) - | FROM ( - | SELECT status_code, COUNT(*) AS count_status - | FROM $testTable - | GROUP BY status_code - | ORDER BY count_status DESC - | LIMIT 5 - | )) AS top_values, - | COUNT(*) - COUNT(status_code) AS Nulls - | FROM $testTable - | GROUP BY typeof(status_code) - | - | UNION ALL - | - | SELECT - | 'id' AS Field, - | COUNT(id) AS Count, - | COUNT(DISTINCT id) AS Distinct, - | MIN(id) AS Min, - | MAX(id) AS Max, - | AVG(CAST(id AS DOUBLE)) AS Avg, - | typeof(id) AS Type, - | (SELECT COLLECT_LIST(STRUCT(id, count_id)) - | FROM ( - | SELECT id, COUNT(*) AS count_id - | FROM $testTable - | GROUP BY id - | ORDER BY count_id DESC - | LIMIT 5 - | )) AS top_values, - | COUNT(*) - COUNT(id) AS Nulls - | FROM $testTable - | GROUP BY typeof(id) - |""".stripMargin) + | source = $testTable | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + +/* + val frame = sql(s""" + | SELECT + | 'status_code' AS Field, + | COUNT(status_code) AS Count, + | COUNT(DISTINCT status_code) AS Distinct, + | MIN(status_code) AS Min, + | MAX(status_code) AS Max, + | AVG(CAST(status_code AS DOUBLE)) AS Avg, + | typeof(status_code) AS Type, + | COUNT(*) - COUNT(status_code) AS Nulls + | FROM $testTable + | GROUP BY typeof(status_code) + | """.stripMargin) +*/ + +// val frame = sql(s""" +// | SELECT +// | 'status_code' AS Field, +// | COUNT(status_code) AS Count, +// | COUNT(DISTINCT status_code) AS Distinct, +// | MIN(status_code) AS Min, +// | MAX(status_code) AS Max, +// | AVG(CAST(status_code AS DOUBLE)) AS Avg, +// | typeof(status_code) AS Type, +// | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) +// | FROM ( +// | SELECT status_code, COUNT(*) AS count_status +// | FROM $testTable +// | GROUP BY status_code +// | ORDER BY count_status DESC +// | LIMIT 5 +// | )) AS top_values, +// | COUNT(*) - COUNT(status_code) AS Nulls +// | FROM $testTable +// | GROUP BY typeof(status_code) +// | +// | UNION ALL +// | +// | SELECT +// | 'id' AS Field, +// | COUNT(id) AS Count, +// | COUNT(DISTINCT id) AS Distinct, +// | MIN(id) AS Min, +// | MAX(id) AS Max, +// | AVG(CAST(id AS DOUBLE)) AS Avg, +// | typeof(id) AS Type, +// | (SELECT COLLECT_LIST(STRUCT(id, count_id)) +// | FROM ( +// | SELECT id, COUNT(*) AS count_id +// | FROM $testTable +// | GROUP BY id +// | ORDER BY count_id DESC +// | LIMIT 5 +// | )) AS top_values, +// | COUNT(*) - COUNT(id) AS Nulls +// | FROM $testTable +// | GROUP BY typeof(id) +// |""".stripMargin) val results: Array[Row] = frame.collect() // Print each row in a readable format + val logicalPlan: LogicalPlan = frame.queryExecution.logical // scalastyle:off println results.foreach(row => println(row.mkString(", "))) + println(logicalPlan) // scalastyle:on println -// val logicalPlan: LogicalPlan = frame.queryExecution.logical // val expectedPlan = ? // comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java index a442d77fa..1d3b9ffed 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -8,7 +8,6 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Field; @@ -18,7 +17,6 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; -import java.util.stream.Collectors; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; @@ -30,7 +28,7 @@ public class FieldSummary extends UnresolvedPlan { private List includeFields; private int topValues; - private boolean nulls; + private boolean ignoreNull; private List collect; private UnresolvedPlan child; @@ -40,7 +38,7 @@ public FieldSummary(List collect) { .forEach(exp -> { switch (((NamedExpression) exp).getExpressionId()) { case NULLS: - this.nulls = (boolean) ((Literal) exp.getChild().get(0)).getValue(); + this.ignoreNull = (boolean) ((Literal) exp.getChild().get(0)).getValue(); break; case TOPVALUES: this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 6b549663a..1f58f92d1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -160,6 +160,7 @@ public enum BuiltinFunctionName { AVG(FunctionName.of("avg")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), + COUNT_DISTINCT(FunctionName.of("count_distinct")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), // sample variance diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 76a7a0c79..8482f4be2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -383,6 +383,27 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { } @Override + /** + * 'Union false, false + * :- 'Aggregate ['typeof('status_code)], [status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30] + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, 'status_code, count_status, 'count_status)), None)] + * : : +- 'SubqueryAlias __auto_generated_subquery_name + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41] + * : +- 'Project [unresolvedalias('COLLECT_LIST(struct(id, 'id, count_id, 'count_id)), None)] + * : +- 'SubqueryAlias __auto_generated_subquery_name + * : +- 'GlobalLimit 5 + * : +- 'LocalLimit 5 + * : +- 'Sort ['count_id DESC NULLS LAST], true + * : +- 'Aggregate ['id], ['id, 'COUNT(1) AS count_id#38] + * : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getChild().get(0).accept(this, context); return FieldSummaryTransformer.translate(fieldSummary, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index 3c367a948..93e2121d3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -12,10 +12,12 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; +import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -27,30 +29,36 @@ */ public interface AggregatorTranslator { + static String aggregationAlias(BuiltinFunctionName functionName, QualifiedName name) { + return functionName.name()+"("+name.toString()+")"; + } + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + boolean distinct = aggregateFunction.getDistinct(); // Additional aggregation function operators will be added here - switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get(); + switch (functionName) { case MAX: - return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: - return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); case AVG: - return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: - return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: - return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); case STDDEV_POP: - return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: - return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false); case PERCENTILE: - return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: - return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java index 225fc7a83..9cc0582f3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -5,33 +5,38 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.Alias$; import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ppl.CatalystPlanContext; import scala.Option; + import java.util.Collections; import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT_DISTINCT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; +import static org.opensearch.sql.ppl.utils.AggregatorTranslator.aggregationAlias; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; public interface FieldSummaryTransformer { - String COUNT = "Count"; - String COUNT_DISTINCT = "CountDistinct"; - String MAX = "Max"; - String MIN = "Min"; - String AVG = "Avg"; - String TYPE = "Type"; String TOP_VALUES = "TopValues"; String NULLS = "Nulls"; - String FIELD = "Field"; /** @@ -39,147 +44,135 @@ public interface FieldSummaryTransformer { * ----------------------------------------------------- * // for each column create statement: * SELECT - * 'column-1' AS Field, - * COUNT(column-1) AS Count, - * COUNT(DISTINCT column-1) AS Distinct, - * MIN(column-1) AS Min, - * MAX(column-1) AS Max, - * AVG(CAST(column-1 AS DOUBLE)) AS Avg, - * typeof(column-1) AS Type, - * (SELECT COLLECT_LIST(STRUCT(column-1, count_status)) + * 'columnA' AS Field, + * COUNT(columnA) AS Count, + * COUNT(DISTINCT columnA) AS Distinct, + * MIN(columnA) AS Min, + * MAX(columnA) AS Max, + * AVG(CAST(columnA AS DOUBLE)) AS Avg, + * typeof(columnA) AS Type, + * (SELECT COLLECT_LIST(STRUCT(columnA, count_status)) * FROM ( - * SELECT column-1, COUNT(*) AS count_status + * SELECT columnA, COUNT(*) AS count_status * FROM $testTable - * GROUP BY column-1 + * GROUP BY columnA * ORDER BY count_status DESC * LIMIT 5 * )) AS top_values, - * COUNT(*) - COUNT(column-1) AS Nulls + * COUNT(*) - COUNT(columnA) AS Nulls * FROM $testTable - * GROUP BY typeof(column-1) + * GROUP BY typeof(columnA) * * // union all queries * UNION ALL * * SELECT - * 'column-2' AS Field, - * COUNT(column-2) AS Count, - * COUNT(DISTINCT column-2) AS Distinct, - * MIN(column-2) AS Min, - * MAX(column-2) AS Max, - * AVG(CAST(column-2 AS DOUBLE)) AS Avg, - * typeof(column-2) AS Type, - * (SELECT COLLECT_LIST(STRUCT(column-2, count_column-2)) + * 'columnB' AS Field, + * COUNT(columnB) AS Count, + * COUNT(DISTINCT columnB) AS Distinct, + * MIN(columnB) AS Min, + * MAX(columnB) AS Max, + * AVG(CAST(columnB AS DOUBLE)) AS Avg, + * typeof(columnB) AS Type, + * (SELECT COLLECT_LIST(STRUCT(columnB, count_columnB)) * FROM ( * SELECT column-, COUNT(*) AS count_column- * FROM $testTable - * GROUP BY column-2 + * GROUP BY columnB * ORDER BY count_column- DESC * LIMIT 5 * )) AS top_values, - * COUNT(*) - COUNT(column-2) AS Nulls + * COUNT(*) - COUNT(columnB) AS Nulls * FROM $testTable - * GROUP BY typeof(column-2) + * GROUP BY typeof(columnB) */ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { fieldSummary.getIncludeFields().forEach(field -> { - Literal fieldLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + Literal fieldNameLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(field.getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(field.getField().getParts())); context.withProjectedFields(Collections.singletonList(field)); - //Alias for the field name as Field - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(fieldLiteral, - FIELD, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); //Alias for the count(field) as Count - UnresolvedFunction count = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(count, - COUNT, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + aggregationAlias(COUNT,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); //Alias for the count(DISTINCT field) as CountDistinct - UnresolvedFunction countDistinct = new UnresolvedFunction(seq("COUNT"), seq(fieldLiteral), true, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(countDistinct, - COUNT_DISTINCT, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + aggregationAlias(COUNT_DISTINCT,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + //Alias for the MAX(field) as MAX - UnresolvedFunction max = new UnresolvedFunction(seq("MAX"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(max, - MAX, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - - //Alias for the MAX(field) as Min - UnresolvedFunction min = new UnresolvedFunction(seq("MIN"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(min, - MIN, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + aggregationAlias(MAX,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); - //Alias for the AVG(field) as Avg - UnresolvedFunction avg = new UnresolvedFunction(seq("AVG"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(avg, - AVG, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + aggregationAlias(MIN,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); - //Alias for the typeOf(field) as Type - UnresolvedFunction type = new UnresolvedFunction(seq("TYPEOF"), seq(fieldLiteral), false, empty(), false); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(type, - TYPE, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - - // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values - CreateNamedStruct structExpr = new CreateNamedStruct(seq( - fieldLiteral, - count - )); - UnresolvedFunction collectList = new UnresolvedFunction( - seq("COLLECT_LIST"), - seq(structExpr), - false, + //Alias for the AVG(field) as Avg + UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + aggregationAlias(AVG,field.getField()), + NamedExpression.newExprId(), + seq(), empty(), - false - ); - context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( - collectList, - TOP_VALUES, - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()) - )); - - if (fieldSummary.isNulls()) { + seq()); + + if (fieldSummary.getTopValues()>0) { + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIgnoreNull() + ); + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + Option.empty(), + seq() + )); + } + + if (!fieldSummary.isIgnoreNull()) { // Alias COUNT(*) - COUNT(column2) AS Nulls UnresolvedFunction countStar = new UnresolvedFunction( - seq("COUNT"), + seq(COUNT.name()), seq(org.apache.spark.sql.catalyst.expressions.Literal.create(1, IntegerType)), false, empty(), @@ -191,11 +184,24 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont new Subtract(countStar, count), NULLS, NamedExpression.newExprId(), - seq(new java.util.ArrayList()), + seq(), Option.empty(), - seq(new java.util.ArrayList()) + seq() )); } + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + aggregationAlias(TYPEOF,field.getField()), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + context.apply(p-> new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, typeOfAlias), p)); + }); return context.getPlan(); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala index e90d587ae..f41f9808c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala @@ -9,12 +9,11 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project} class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite extends SparkFunSuite @@ -25,32 +24,37 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - ignore("test fieldsummary with `includefields=status_code,user_id,response_time`") { + test("test fieldsummary with single field includefields(status_code) & nulls=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan( pplParser, - "source = t | fieldsummary includefields= status_code, user_id, response_time topvalues=5 nulls=true"), + "source = t | fieldsummary includefields= status_code nulls=true"), context) + // Define the table val table = UnresolvedRelation(Seq("t")) - val renameProjectList: Seq[NamedExpression] = - Seq( - UnresolvedStar(None), - Alias( - UnresolvedFunction( - "coalesce", - Seq(UnresolvedAttribute("column_name"), Literal("null replacement value")), - isDistinct = false), - "column_name")()) - val renameProject = Project(renameProjectList, table) - - val dropSourceColumn = - DataFrameDropColumns(Seq(UnresolvedAttribute("column_name")), renameProject) - - val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) - comparePlans(expectedPlan, logPlan, checkAnalysis = false) + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "COUNT(status_code)")(), + Alias(UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), "COUNT_DISTINCT(status_code)")(), + Alias(UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MIN(status_code)")(), + Alias(UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "MAX(status_code)")(), + Alias(UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "AVG(status_code)")(), + Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")() + ) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias(UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), "TYPEOF(status_code)")()), + aggregateExpressions, + table + ) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logPlan)) } }