diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index d161613a6..d22fc7b63 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -33,6 +33,12 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | eval b1 = b + 1 | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) - `source = table | eval b1 = lower(b) | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) +**Field-Summary** +[See additional command details](ppl-fieldsummary-command.md) +- `source = t | fieldsummary includefields=status_code nulls=false` +- `source = t | fieldsummary includefields= id, status_code, request_path nulls=true` +- `source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true` + **Nested-Fields** - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1` - `source = catalog.table | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield` diff --git a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md new file mode 100644 index 000000000..468c2046b --- /dev/null +++ b/docs/ppl-lang/ppl-fieldsummary-command.md @@ -0,0 +1,83 @@ +## PPL `fieldsummary` command + +**Description** +Using `fieldsummary` command to : + - Calculate basic statistics for each field (count, distinct count, min, max, avg, stddev, mean ) + - Determine the data type of each field + +**Syntax** + +`... | fieldsummary (nulls=true/false)` + +* command accepts any preceding pipe before the terminal `fieldsummary` command and will take them into account. +* `includefields`: list of all the columns to be collected with statistics into a unified result set +* `nulls`: optional; if the true, include the null values in the aggregation calculations (replace null with zero for numeric values) + +### Example 1: + +PPL query: + + os> source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | Fields | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | NUlls | TYPEOF | + |------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "status_code" | 2 | 2 | 301 | 403 | 352.0 | 352.0 | 72.12489168102785 | 0 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + +### Example 2: + +PPL query: + + os> source = t | fieldsummary includefields= id, status_code, request_path nulls=true + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | Fields | COUNT | COUNT_DISTINCT | MIN | MAX | AVG | MEAN | STDDEV | NUlls | TYPEOF | + |------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "id" | 6 | 6 | 1 | 6 | 3.5 | 3.5 | 1.8708286933869707 | 0 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "status_code" | 4 | 3 | 200 | 403 | 184.0 | 184.0 | 161.16699413961905 | 2 | "int" | + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + | "request_path" | 2 | 2 | /about| /home | 0.0 | 0.0 | 0 | 2 |"string"| + +------------------+-------------+------------+------------+------------+------------+------------+------------+----------------| + +### Additional Info +The actual query is translated into the following SQL-like statement: + +```sql + SELECT + id AS Field, + COUNT(id) AS COUNT, + COUNT(DISTINCT id) AS COUNT_DISTINCT, + MIN(id) AS MIN, + MAX(id) AS MAX, + AVG(id) AS AVG, + MEAN(id) AS MEAN, + STDDEV(id) AS STDDEV, + (COUNT(1) - COUNT(id)) AS Nulls, + TYPEOF(id) AS TYPEOF + FROM + t + GROUP BY + TYPEOF(status_code), status_code; +UNION + SELECT + status_code AS Field, + COUNT(status_code) AS COUNT, + COUNT(DISTINCT status_code) AS COUNT_DISTINCT, + MIN(status_code) AS MIN, + MAX(status_code) AS MAX, + AVG(status_code) AS AVG, + MEAN(status_code) AS MEAN, + STDDEV(status_code) AS STDDEV, + (COUNT(1) - COUNT(status_code)) AS Nulls, + TYPEOF(status_code) AS TYPEOF + FROM + t + GROUP BY + TYPEOF(status_code), status_code; +``` +For each such columns (id, status_code) there will be a unique statement and all the fields will be presented togather in the result using a UNION operator + + +### Limitation: + - `topvalues` option was removed from this command due the possible performance impact of such sub-query. As an alternative one can use the `top` command directly as shown [here](ppl-top-command.md). + diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala new file mode 100644 index 000000000..5a5990001 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -0,0 +1,751 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder, Subtract} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFieldSummaryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNullableTableHttpLog(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 184.0, 184.0, 161.16699413961905, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=false ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 276.0, 276.0, 97.1356439899038, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=true + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 184.0, 184.0, 161.16699413961905, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", 0.0, 0.0, 0.0, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=false + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 276.0, 276.0, 97.1356439899038, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, null, null, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index ed170449a..6138a94a2 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -86,6 +86,12 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; + +// FIELDSUMMARY keywords +FIELDSUMMARY: 'FIELDSUMMARY'; +INCLUDEFIELDS: 'INCLUDEFIELDS'; +NULLS: 'NULLS'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 9686b0139..ae5f14498 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -52,6 +52,7 @@ commands | lookupCommand | renameCommand | fillnullCommand + | fieldsummaryCommand ; searchCommand @@ -59,6 +60,15 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; + +fieldsummaryCommand + : FIELDSUMMARY (fieldsummaryParameter)* + ; + +fieldsummaryParameter + : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields + | NULLS EQUAL booleanLiteral # fieldsummaryNulls + ; describeCommand : DESCRIBE tableSourceClause @@ -1088,6 +1098,10 @@ keywordsCanBeId | SPARKLINE | C | DC + // FIELD SUMMARY + | FIELDSUMMARY + | INCLUDEFIELDS + | NULLS // JOIN TYPE | OUTER | INNER diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index c361ded08..5ac54127b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -16,6 +16,8 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -206,6 +208,10 @@ public T visitField(Field node, C context) { return visitChildren(node, context); } + public T visitFieldList(FieldList node, C context) { + return visitChildren(node, context); + } + public T visitQualifiedName(QualifiedName node, C context) { return visitChildren(node, context); } @@ -296,9 +302,14 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + public T visitFillNull(FillNull fillNull, C context) { return visitChildren(fillNull, context); } + + public T visitFieldSummary(FieldSummary fieldSummary, C context) { + return visitChildren(fieldSummary, context); + } public T visitScalarSubquery(ScalarSubquery node, C context) { return visitChildren(node, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java new file mode 100644 index 000000000..4f6ac5e14 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Expression node that includes a list of fields nodes. */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@AllArgsConstructor +public class FieldList extends UnresolvedExpression { + private final List fieldList; + + @Override + public List getChild() { + return ImmutableList.copyOf(fieldList); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldList(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java new file mode 100644 index 000000000..a8072e76b --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +public class FieldSummary extends UnresolvedPlan { + private List includeFields; + private boolean includeNull; + private List collect; + private UnresolvedPlan child; + + public FieldSummary(List collect) { + this.collect = collect; + collect.forEach(exp -> { + if (exp instanceof Argument) { + this.includeNull = (boolean) ((Argument)exp).getValue().getValue(); + } + if (exp instanceof AttributeList) { + this.includeFields = ((AttributeList)exp).getAttrList(); + } + }); + } + + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldSummary(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f44fe26d8..9e1a9a743 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -164,6 +164,8 @@ public enum BuiltinFunctionName { /** Aggregation Function. */ AVG(FunctionName.of("avg")), + MEAN(FunctionName.of("mean")), + STDDEV(FunctionName.of("stddev")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 46a016d1a..61762f616 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -154,7 +154,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + + public LogicalPlan applyBranches(List> plans) { + plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); + planBranches.remove(0); + return getPlan(); + } + /** * append plan with evolving plans branches * @@ -281,4 +287,5 @@ public static Optional findRelation(LogicalPlan plan) { // Return null if no UnresolvedRelation is found return Optional.empty(); } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 902fc72e3..76a7a0c79 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -65,6 +65,7 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; @@ -85,6 +86,7 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -380,6 +382,12 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { node.getSize(), DataTypes.IntegerType), p)); } + @Override + public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { + fieldSummary.getChild().get(0).accept(this, context); + return FieldSummaryTransformer.translate(fieldSummary, context); + } + @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { fillNull.getChild().get(0).accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 1c0fe919f..26a8e2278 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -415,8 +416,14 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) groupListBuilder.build()); return aggregation; } - - /** Rare command. */ + + /** Fieldsummary command. */ + @Override + public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryCommandContext ctx) { + return new FieldSummary(ctx.fieldsummaryParameter().stream().map(arg -> expressionBuilder.visit(arg)).collect(Collectors.toList())); + } + + /** Rare command. */ @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 3b98edd77..ea51ca7a1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -16,11 +16,13 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; @@ -39,6 +41,7 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -50,6 +53,8 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; @@ -179,6 +184,20 @@ public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.FieldsummaryIncludeFieldsContext ctx) { + List list = ctx.fieldList().fieldExpression().stream() + .map(this::visitFieldExpression) + .collect(Collectors.toList()); + return new AttributeList(list); + } + + @Override + public UnresolvedExpression visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { + return new Argument("NULLS",(Literal)visitBooleanLiteral(ctx.booleanLiteral())); + } + + /** * Aggregation function. */ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index 3c367a948..a01b38a80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -12,10 +12,12 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; +import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -26,31 +28,37 @@ * @return */ public interface AggregatorTranslator { - + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + boolean distinct = aggregateFunction.getDistinct(); // Additional aggregation function operators will be added here - switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get(); + switch (functionName) { case MAX: - return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: - return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); + case MEAN: + return new UnresolvedFunction(seq("MEAN"), seq(arg), distinct, empty(),false); case AVG: - return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: - return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: - return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); + case STDDEV: + return new UnresolvedFunction(seq("STDDEV"), seq(arg), distinct, empty(),false); case STDDEV_POP: - return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: - return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false); case PERCENTILE: - return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: - return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 4345b0897..62eef90ed 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -20,7 +20,10 @@ import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; +import java.util.Arrays; import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -41,6 +44,7 @@ public interface DataTypeTransformer { static Seq seq(T... elements) { return seq(List.of(elements)); } + static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java new file mode 100644 index 000000000..dd8f01874 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.AliasIdentifier; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LocalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.Sort; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystPlanContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MEAN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface FieldSummaryTransformer { + + String TOP_VALUES = "TopValues"; + String NULLS = "Nulls"; + String FIELD = "Field"; + + /** + * translate the command into the aggregate statement group by the column name + */ + static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { + List> aggBranches = fieldSummary.getIncludeFields().stream() + .filter(field -> field instanceof org.opensearch.sql.ast.expression.Field ) + .map(field -> { + Literal fieldNameLiteral = Literal.create(((org.opensearch.sql.ast.expression.Field)field).getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(((org.opensearch.sql.ast.expression.Field)field).getField().getParts())); + context.withProjectedFields(Collections.singletonList(field)); + + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(field) as Count + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(DISTINCT field) as CountDistinct + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + "DISTINCT", + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MAX(field) as MAX + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + MAX.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + MIN.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the AVG(field) as Avg + Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); + + //Alias for the MEAN(field) as Mean + Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); + + //Alias for the STDDEV(field) as Stddev + Alias stddevAlias = getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); + Alias nonNullAlias = Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + TYPEOF.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + return (Function) p -> + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); + }).collect(Collectors.toList()); + + return context.applyBranches(aggBranches); + } + + /** + * Alias for aggregate function (if isIncludeNull use COALESCE to replace nulls with zeros) + */ + private static Alias getAggMethodAlias(BuiltinFunctionName method, FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { + UnresolvedFunction avg = new UnresolvedFunction(seq(method.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + if (fieldSummary.isIncludeNull()) { + UnresolvedFunction coalesceExpr = new UnresolvedFunction( + seq("COALESCE"), + seq(fieldLiteral, Literal.create(0, DataTypes.IntegerType)), + false, + empty(), + false + ); + avg = new UnresolvedFunction(seq(method.name()), seq(coalesceExpr), false, empty(), false); + avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + } + return avgAlias; + } + + /** + * top values sub-query + */ + private static Alias topValuesSubQueryAlias(FieldSummary fieldSummary, CatalystPlanContext context, UnresolvedAttribute fieldLiteral, UnresolvedFunction count) { + int topValues = 5;// this value should come from the FieldSummary definition + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIncludeNull() + ); + Alias topValuesAlias = Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + Project subQueryProject = new Project(seq(topValuesAlias), buildTopValueSubQuery(topValues, fieldLiteral, context)); + ScalarSubquery scalarSubquery = ScalarSubquery$.MODULE$.apply( + subQueryProject, + seq(new ArrayList()), + NamedExpression.newExprId(), + seq(new ArrayList()), + empty(), + empty()); + + return Alias$.MODULE$.apply( + scalarSubquery, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + } + + /** + * inner top values query + * ----------------------------------------------------- + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, count_status)), None)] + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + private static LogicalPlan buildTopValueSubQuery(int topValues, UnresolvedAttribute fieldLiteral, CatalystPlanContext context) { + //Alias for the count(field) as Count + UnresolvedFunction countFunc = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(countFunc, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + Aggregate aggregate = new Aggregate(seq(fieldLiteral), seq(countAlias), context.getPlan()); + Project project = new Project(seq(fieldLiteral, countAlias), aggregate); + SortOrder sortOrder = new SortOrder(countAlias, Descending$.MODULE$, Ascending$.MODULE$.defaultNullOrdering(), seq()); + Sort sort = new Sort(seq(sortOrder), true, project); + GlobalLimit limit = new GlobalLimit(Literal.create(topValues, IntegerType), new LocalLimit(Literal.create(topValues, IntegerType), sort)); + return new SubqueryAlias(new AliasIdentifier(TOP_VALUES + "_subquery"), limit); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala new file mode 100644 index 000000000..c14e1f6cf --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -0,0 +1,709 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not, Subtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project, Union} + +class PPLLogicalPlanFieldSummaryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fieldsummary with single field includefields(status_code) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } +}