diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala
index 1a6a1006d..50f325000 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala
@@ -8,7 +8,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.streaming.StreamTest
 
@@ -85,22 +85,57 @@ class FlintSparkPPLFieldSummaryITSuite
     comparePlans(expectedPlan, logicalPlan, false)
   }
 
-  /**
-   * // val frame = sql(s""" // | SELECT // | 'status_code' AS Field, // | COUNT(status_code) AS
-   * Count, // | COUNT(DISTINCT status_code) AS Distinct, // | MIN(status_code) AS Min, // |
-   * MAX(status_code) AS Max, // | AVG(CAST(status_code AS DOUBLE)) AS Avg, // |
-   * typeof(status_code) AS Type, // | (SELECT COLLECT_LIST(STRUCT(status_code, count_status)) //
-   * \| FROM ( // | SELECT status_code, COUNT(*) AS count_status // | FROM $testTable // | GROUP
-   * BY status_code // | ORDER BY count_status DESC // | LIMIT 5 // | )) AS top_values, // |
-   * COUNT(*) - COUNT(status_code) AS Nulls // | FROM $testTable // | GROUP BY typeof(status_code)
-   * // | // | UNION ALL // | // | SELECT // | 'id' AS Field, // | COUNT(id) AS Count, // |
-   * COUNT(DISTINCT id) AS Distinct, // | MIN(id) AS Min, // | MAX(id) AS Max, // | AVG(CAST(id AS
-   * DOUBLE)) AS Avg, // | typeof(id) AS Type, // | (SELECT COLLECT_LIST(STRUCT(id, count_id)) //
-   * \| FROM ( // | SELECT id, COUNT(*) AS count_id // | FROM $testTable // | GROUP BY id // |
-   * ORDER BY count_id DESC // | LIMIT 5 // | )) AS top_values, // | COUNT(*) - COUNT(id) AS Nulls
-   * // | FROM $testTable // | GROUP BY typeof(id) // |""".stripMargin) // Aggregate with
-   * functions applied to status_code
-   */
+  test(
+    "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") {
+    val frame = sql(s"""
+                       | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=true
+                       | """.stripMargin)
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(Row("status_code", 2, 2, 301, 403, 352.0, "int"))
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    // Aggregate with functions applied to status_code
+    val aggregateExpressions: Seq[NamedExpression] = Seq(
+      Alias(Literal("status_code"), "Field")(),
+      Alias(
+        UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "COUNT")(),
+      Alias(
+        UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true),
+        "COUNT_DISTINCT")(),
+      Alias(
+        UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "MIN")(),
+      Alias(
+        UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "MAX")(),
+      Alias(
+        UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "AVG")(),
+      Alias(
+        UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "TYPEOF")())
+
+    val table =
+      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200)))
+    val aggregatePlan = Aggregate(
+      groupingExpressions = Seq(Alias(
+        UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "TYPEOF")()),
+      aggregateExpressions,
+      Filter(filterCondition, table))
+
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan)
+    // Compare the two plans
+    comparePlans(expectedPlan, logicalPlan, false)
+  }
+
   test(
     "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") {
     val frame = sql(s"""
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala
index 5d376e18b..ed0f078c0 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite.scala
@@ -12,9 +12,9 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Project, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Filter, Project, Union}
 
 class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite
     extends SparkFunSuite
@@ -69,6 +69,55 @@ class PPLLogicalPlanFieldSummaryCommandTranslatorTestSuite
     comparePlans(expectedPlan, logPlan, false)
   }
 
+  test(
+    "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true"),
+        context)
+
+    // Define the table
+    val table = UnresolvedRelation(Seq("t"))
+
+    // Aggregate with functions applied to status_code
+    val aggregateExpressions: Seq[NamedExpression] = Seq(
+      Alias(Literal("status_code"), "Field")(),
+      Alias(
+        UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "COUNT")(),
+      Alias(
+        UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true),
+        "COUNT_DISTINCT")(),
+      Alias(
+        UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "MIN")(),
+      Alias(
+        UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "MAX")(),
+      Alias(
+        UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "AVG")(),
+      Alias(
+        UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "TYPEOF")())
+
+    val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200)))
+    val aggregatePlan = Aggregate(
+      groupingExpressions = Seq(Alias(
+        UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false),
+        "TYPEOF")()),
+      aggregateExpressions,
+      Filter(filterCondition, table))
+
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan)
+
+    // Compare the two plans
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
   test(
     "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") {
     val context = new CatalystPlanContext