From ed1b24c9391e2dda8a80483cf7d4927710240dee Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Fri, 25 Oct 2024 10:07:37 -0700 Subject: [PATCH 1/5] Support alter refresh interval on external scheduler (#801) * Support alter refresh interval on external scheduler Signed-off-by: Louis Chu * Add more ITs Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .../opensearch/flint/spark/FlintSpark.scala | 34 ++-- .../flint/spark/FlintSparkIndexOptions.scala | 1 - .../spark/FlintSparkIndexBuilderSuite.scala | 8 +- .../FlintSparkMaterializedViewITSuite.scala | 1 - .../spark/FlintSparkUpdateIndexITSuite.scala | 146 +++++++++++++++++- 5 files changed, 171 insertions(+), 19 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index e5805731b..532bd8e60 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -23,6 +23,7 @@ import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode._ +import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.opensearch.flint.spark.scheduler.{AsyncQuerySchedulerBuilder, FlintSparkJobExternalSchedulingService, FlintSparkJobInternalSchedulingService, FlintSparkJobSchedulingService} import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex @@ -229,16 +230,16 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val originalOptions = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) .options - validateUpdateAllowed(originalOptions, index.options) - val isSchedulerModeChanged = - index.options.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() + validateUpdateAllowed(originalOptions, index.options) withTransaction[Option[String]](indexName, "Update Flint index") { tx => // Relies on validation to prevent: // 1. auto-to-auto updates besides scheduler_mode // 2. any manual-to-manual updates // 3. 
both refresh_mode and scheduler_mode updated - (index.options.autoRefresh(), isSchedulerModeChanged) match { + ( + index.options.autoRefresh(), + isSchedulerModeChanged(originalOptions, index.options)) match { case (true, true) => updateSchedulerMode(index, tx) case (true, false) => updateIndexManualToAuto(index, tx) case (false, false) => updateIndexAutoToManual(index, tx) @@ -478,11 +479,17 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w "Altering index to full/incremental refresh") case (false, true) => - // original refresh_mode is auto, only allow changing scheduler_mode - validateChangedOptions( - changedOptions, - Set(SCHEDULER_MODE), - "Altering index when auto_refresh remains true") + // original refresh_mode is auto, only allow changing scheduler_mode and potentially refresh_interval + var allowedOptions = Set(SCHEDULER_MODE) + val schedulerMode = + if (updatedOptions.isExternalSchedulerEnabled()) SchedulerMode.EXTERNAL + else SchedulerMode.INTERNAL + val contextPrefix = + s"Altering index when auto_refresh remains true and scheduler_mode is $schedulerMode" + if (updatedOptions.isExternalSchedulerEnabled()) { + allowedOptions += REFRESH_INTERVAL + } + validateChangedOptions(changedOptions, allowedOptions, contextPrefix) case (false, false) => // original refresh_mode is full/incremental, not allowed to change any options @@ -507,6 +514,12 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w } } + private def isSchedulerModeChanged( + originalOptions: FlintSparkIndexOptions, + updatedOptions: FlintSparkIndexOptions): Boolean = { + updatedOptions.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() + } + private def updateIndexAutoToManual( index: FlintSparkIndex, tx: OptimisticTransaction[Option[String]]): Option[String] = { @@ -587,7 +600,8 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) logInfo("Update index options complete") oldService.handleJob(index, AsyncQuerySchedulerAction.UNSCHEDULE) - logInfo(s"Unscheduled ${if (isExternal) "internal" else "external"} jobs") + logInfo( + s"Unscheduled refresh jobs from ${if (isExternal) "internal" else "external"} scheduler") newService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) }) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index fc1a611ef..9b58a696c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -10,7 +10,6 @@ import java.util.{Collections, UUID} import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization -import org.opensearch.flint.core.logging.CustomLogging.logInfo import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode diff --git 
a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala index 4ff5b5adb..80b788253 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala @@ -210,12 +210,12 @@ class FlintSparkIndexBuilderSuite Some( "spark.flint.job.externalScheduler.enabled is false but scheduler_mode is set to external")), ( - "set external mode when interval above threshold and no mode specified", + "set external mode when interval below threshold and no mode specified", true, "5 minutes", - Map("auto_refresh" -> "true", "refresh_interval" -> "10 minutes"), - Some(SchedulerMode.EXTERNAL.toString), - Some("10 minutes"), + Map("auto_refresh" -> "true", "refresh_interval" -> "1 minutes"), + Some(SchedulerMode.INTERNAL.toString), + Some("1 minutes"), None), ( "throw exception when interval below threshold but mode is external", diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index c2f0f9101..14d41c2bb 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -17,7 +17,6 @@ import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName -import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.CHECKPOINT_LOCATION import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.scalatest.matchers.must.Matchers.defined diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index 53889045f..a6f7e0ed0 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -5,10 +5,16 @@ package org.opensearch.flint.spark +import scala.jdk.CollectionConverters.mapAsJavaMapConverter + import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ +import org.opensearch.OpenSearchException +import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions -import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} +import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.opensearch.index.query.QueryBuilders import org.opensearch.index.reindex.DeleteByQueryRequest @@ -180,6 +186,96 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { } } + test("update 
auto refresh index to switch scheduler mode") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + + withTempDir { checkpointDir => + // Create auto refresh Flint index + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options( + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "4 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testIndex) + .create() + flint.refreshIndex(testIndex) + + val indexInitial = flint.describeIndex(testIndex).get + indexInitial.options.refreshInterval() shouldBe Some("4 Minute") + the[OpenSearchException] thrownBy { + val client = + OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) + client.get( + new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, testIndex), + RequestOptions.DEFAULT) + } + + // Update Flint index to change refresh interval + val updatedIndex = flint + .skippingIndex() + .copyWithUpdate( + indexInitial, + FlintSparkIndexOptions( + Map("scheduler_mode" -> "external", "refresh_interval" -> "5 Minutes"))) + flint.updateIndex(updatedIndex) + + // Verify index after update + val indexFinal = flint.describeIndex(testIndex).get + indexFinal.options.autoRefresh() shouldBe true + indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) + + // Verify scheduler index is updated + verifySchedulerIndex(testIndex, 5, "MINUTES") + } + } + + test("update auto refresh index to change refresh interval") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + + withTempDir { checkpointDir => + // Create auto refresh Flint index + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options( + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "10 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath)), + testIndex) + .create() + + val indexInitial = flint.describeIndex(testIndex).get + indexInitial.options.refreshInterval() shouldBe Some("10 Minute") + verifySchedulerIndex(testIndex, 10, "MINUTES") + + // Update Flint index to change refresh interval + val updatedIndex = flint + .skippingIndex() + .copyWithUpdate( + indexInitial, + FlintSparkIndexOptions(Map("refresh_interval" -> "5 Minutes"))) + flint.updateIndex(updatedIndex) + + // Verify index after update + val indexFinal = flint.describeIndex(testIndex).get + indexFinal.options.autoRefresh() shouldBe true + indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) + + // Verify scheduler index is updated + verifySchedulerIndex(testIndex, 5, "MINUTES") + } + } + // Test update options validation failure with external scheduler Seq( ( @@ -207,12 +303,32 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { (Map.empty[String, String], Map("checkpoint_location" -> "s3a://test/"))), "No options can be updated when auto_refresh remains false"), ( - "update other index option besides scheduler_mode when auto_refresh is true", + "update index option when refresh_interval value belows threshold", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("refresh_interval" -> "4 minutes"))), + "Input refresh_interval is 4 minutes, required above the interval threshold of external scheduler: 5 minutes"), + ( + "update index option when no change 
on auto_refresh", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("scheduler_mode" -> "internal", "refresh_interval" -> "4 minutes")), + ( + Map( + "auto_refresh" -> "true", + "scheduler_mode" -> "internal", + "checkpoint_location" -> "s3a://test/"), + Map("refresh_interval" -> "4 minutes"))), + "Altering index when auto_refresh remains true and scheduler_mode is internal only allows changing: Set(scheduler_mode). Invalid options"), + ( + "update other index option besides scheduler_mode and refresh_interval when auto_refresh is true", Seq( ( Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), Map("watermark_delay" -> "1 Minute"))), - "Altering index when auto_refresh remains true only allows changing: Set(scheduler_mode). Invalid options"), + "Altering index when auto_refresh remains true and scheduler_mode is external only allows changing: Set(scheduler_mode, refresh_interval). Invalid options"), ( "convert to full refresh with disallowed options", Seq( @@ -655,4 +771,28 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { flint.queryIndex(testIndex).collect().toSet should have size 1 } } + + private def verifySchedulerIndex( + indexName: String, + expectedPeriod: Int, + expectedUnit: String): Unit = { + val client = OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) + val response = client.get( + new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, indexName), + RequestOptions.DEFAULT) + + response.isExists shouldBe true + val sourceMap = response.getSourceAsMap + + sourceMap.get("jobId") shouldBe indexName + sourceMap.get( + "scheduledQuery") shouldBe s"REFRESH SKIPPING INDEX ON spark_catalog.default.`test`" + sourceMap.get("enabled") shouldBe true + sourceMap.get("queryLang") shouldBe "sql" + + val schedule = sourceMap.get("schedule").asInstanceOf[java.util.Map[String, Any]] + val interval = schedule.get("interval").asInstanceOf[java.util.Map[String, Any]] + interval.get("period") shouldBe expectedPeriod + interval.get("unit") shouldBe expectedUnit + } } From ee75048ea44503475ff9992fbfe416c4f9f31a84 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Sat, 26 Oct 2024 01:24:45 +0800 Subject: [PATCH 2/5] Support PPL `JSON` functions: construction and extraction (#780) * first commit Signed-off-by: Lantao Jin * add docs and fix IT Signed-off-by: Lantao Jin * add examples for json_extract() Signed-off-by: Lantao Jin * fix missing import and doc link Signed-off-by: Lantao Jin * minor Signed-off-by: Lantao Jin * add UT and optimize the doc Signed-off-by: Lantao Jin * typo Signed-off-by: Lantao Jin * fix the issue when merge conflicts Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- docs/ppl-lang/README.md | 2 + docs/ppl-lang/functions/ppl-json.md | 237 +++++++++++ .../flint/spark/FlintSparkSuite.scala | 27 ++ .../flint/spark/ppl/FlintPPLSuite.scala | 14 +- .../FlintSparkPPLJsonFunctionITSuite.scala | 386 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 21 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 27 ++ .../function/BuiltinFunctionName.java | 22 + .../ppl/utils/BuiltinFunctionTranslator.java | 107 ++++- ...PlanJsonFunctionsTranslatorTestSuite.scala | 233 +++++++++++ 10 files changed, 1059 insertions(+), 17 deletions(-) create mode 100644 docs/ppl-lang/functions/ppl-json.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala create mode 100644 
ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md
index fd7c36605..4fa9d10cc 100644
--- a/docs/ppl-lang/README.md
+++ b/docs/ppl-lang/README.md
@@ -77,6 +77,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
 
  - [`String Functions`](functions/ppl-string.md)
 
+ - [`JSON Functions`](functions/ppl-json.md)
+
 - [`Condition Functions`](functions/ppl-condition.md)
 
 - [`Type Conversion Functions`](functions/ppl-conversion.md)
diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md
new file mode 100644
index 000000000..1953e8c70
--- /dev/null
+++ b/docs/ppl-lang/functions/ppl-json.md
@@ -0,0 +1,237 @@
+## PPL JSON Functions
+
+### `JSON`
+
+**Description**
+
+`json(value)` Evaluates whether a value can be parsed as JSON. Returns the JSON string if valid, null otherwise.
+
+**Argument type:** STRING/JSON_ARRAY/JSON_OBJECT
+
+**Return type:** STRING
+
+A STRING expression of a valid JSON object format.
+
+Example:
+
+    os> source=people | eval `valid_json()` = json('[1,2,3,{"f1":1,"f2":[5,6]},4]') | fields valid_json
+    fetched rows / total rows = 1/1
+    +---------------------------------+
+    | valid_json                      |
+    +---------------------------------+
+    | [1,2,3,{"f1":1,"f2":[5,6]},4]   |
+    +---------------------------------+
+
+    os> source=people | eval `invalid_json()` = json('{"invalid": "json"') | fields invalid_json
+    fetched rows / total rows = 1/1
+    +----------------+
+    | invalid_json   |
+    +----------------+
+    | null           |
+    +----------------+
+
+
+### `JSON_OBJECT`
+
+**Description**
+
+`json_object(<key>, <value>[, <key>, <value>]...)` returns a JSON object from members of key-value pairs.
+
+**Argument type:**
+- A \<key\> must be STRING.
+- A \<value\> can be any data type.
+
+**Return type:** JSON_OBJECT (Spark StructType)
+
+A StructType expression of a valid JSON object.
+
+Example:
+
+    os> source=people | eval result = json(json_object('key', 123.45)) | fields result
+    fetched rows / total rows = 1/1
+    +------------------+
+    | result           |
+    +------------------+
+    | {"key":123.45}   |
+    +------------------+
+
+    os> source=people | eval result = json(json_object('outer', json_object('inner', 123.45))) | fields result
+    fetched rows / total rows = 1/1
+    +------------------------------+
+    | result                       |
+    +------------------------------+
+    | {"outer":{"inner":123.45}}   |
+    +------------------------------+
+
+
+### `JSON_ARRAY`
+
+**Description**
+
+`json_array(<value>...)` Creates a JSON ARRAY using a list of values.
+
+**Argument type:**
+- A \<value\> can be any kind of value such as string, number, or boolean.
+
+**Return type:** ARRAY (Spark ArrayType)
+
+An array of any supported data type for a valid JSON array.
+
+Example:
+
+    os> source=people | eval `json_array` = json_array(1, 2, 0, -1, 1.1, -0.11)
+    fetched rows / total rows = 1/1
+    +----------------------------+
+    | json_array                 |
+    +----------------------------+
+    | 1.0,2.0,0.0,-1.0,1.1,-0.11 |
+    +----------------------------+
+
+    os> source=people | eval `json_array_object` = json(json_object("array", json_array(1, 2, 0, -1, 1.1, -0.11)))
+    fetched rows / total rows = 1/1
+    +----------------------------------------+
+    | json_array_object                      |
+    +----------------------------------------+
+    | {"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]} |
+    +----------------------------------------+
+
+### `JSON_ARRAY_LENGTH`
+
+**Description**
+
+`json_array_length(jsonArray)` Returns the number of elements in the outermost JSON array.
+
+**Argument type:** STRING/JSON_ARRAY
+
+A STRING expression of a valid JSON array format, or JSON_ARRAY object.
+
+**Return type:** INTEGER
+
+`NULL` is returned for any other valid JSON string, a `NULL` input, or an invalid JSON.
+
+Example:
+
+    os> source=people | eval `length1` = json_array_length('[1,2,3,4]'), `length2` = json_array_length('[1,2,3,{"f1":1,"f2":[5,6]},4]'), `not_array` = json_array_length('{"key": 1}')
+    fetched rows / total rows = 1/1
+    +-----------+-----------+-------------+
+    | length1   | length2   | not_array   |
+    +-----------+-----------+-------------+
+    | 4         | 5         | null        |
+    +-----------+-----------+-------------+
+
+    os> source=people | eval `json_array` = json_array_length(json_array(1,2,3,4)), `empty_array` = json_array_length(json_array())
+    fetched rows / total rows = 1/1
+    +--------------+---------------+
+    | json_array   | empty_array   |
+    +--------------+---------------+
+    | 4            | 0             |
+    +--------------+---------------+
+
+### `JSON_EXTRACT`
+
+**Description**
+
+`json_extract(jsonStr, path)` Extracts a JSON object from a JSON string based on the specified JSON path. Returns null if the input JSON string is invalid.
+
+**Argument type:** STRING, STRING
+
+**Return type:** STRING
+
+A STRING expression of a valid JSON object format.
+
+`NULL` is returned in case of an invalid JSON.
+
+Example:
+
+    os> source=people | eval `json_extract('{"a":"b"}', '$.a')` = json_extract('{"a":"b"}', '$.a')
+    fetched rows / total rows = 1/1
+    +------------------------------------+
+    | json_extract('{"a":"b"}', '$.a')   |
+    +------------------------------------+
+    | b                                  |
+    +------------------------------------+
+
+    os> source=people | eval `json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b')` = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b')
+    fetched rows / total rows = 1/1
+    +-----------------------------------------------------------+
+    | json_extract('{"a":[{"b":1.0},{"b":2.0}]}', '$.a[1].b')   |
+    +-----------------------------------------------------------+
+    | 2.0                                                       |
+    +-----------------------------------------------------------+
+
+    os> source=people | eval `json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[*].b')` = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[*].b')
+    fetched rows / total rows = 1/1
+    +-----------------------------------------------------------+
+    | json_extract('{"a":[{"b":1.0},{"b":2.0}]}', '$.a[*].b')   |
+    +-----------------------------------------------------------+
+    | [1.0,2.0]                                                 |
+    +-----------------------------------------------------------+
+
+    os> source=people | eval `invalid_json` = json_extract('{"invalid": "json"')
+    fetched rows / total rows = 1/1
+    +----------------+
+    | invalid_json   |
+    +----------------+
+    | null           |
+    +----------------+
+
+
+### `JSON_KEYS`
+
+**Description**
+
+`json_keys(jsonStr)` Returns all the keys of the outermost JSON object as an array.
+
+**Argument type:** STRING
+
+A STRING expression of a valid JSON object format.
+
+**Return type:** ARRAY[STRING]
+
+`NULL` is returned for any other valid JSON string, an empty string, or an invalid JSON.
+
+Example:
+
+    os> source=people | eval `keys` = json_keys('{"f1":"abc","f2":{"f3":"a","f4":"b"}}')
+    fetched rows / total rows = 1/1
+    +------------+
+    | keys       |
+    +------------+
+    | [f1, f2]   |
+    +------------+
+
+    os> source=people | eval `keys` = json_keys('[1,2,3,{"f1":1,"f2":[5,6]},4]')
+    fetched rows / total rows = 1/1
+    +--------+
+    | keys   |
+    +--------+
+    | null   |
+    +--------+
+
+### `JSON_VALID`
+
+**Description**
+
+`json_valid(jsonStr)` Evaluates whether a JSON string uses valid JSON syntax and returns TRUE or FALSE.
+
+**Argument type:** STRING
+
+**Return type:** BOOLEAN
+
+Example:
+
+    os> source=people | eval `valid_json` = json_valid('[1,2,3,4]'), `invalid_json` = json_valid('{"invalid": "json"') | fields `valid_json`, `invalid_json`
+    fetched rows / total rows = 1/1
+    +--------------+----------------+
+    | valid_json   | invalid_json   |
+    +--------------+----------------+
+    | True         | False          |
+    +--------------+----------------+
+
+    os> source=accounts | where json_valid('[1,2,3,4]') and isnull(email) | fields account_number, email
+    fetched rows / total rows = 1/1
+    +------------------+---------+
+    | account_number   | email   |
+    |------------------+---------|
+    | 13               | null    |
+    +------------------+---------+
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index 1ecf48d28..23a336b4c 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -642,4 +642,31 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
            | (6, 403, '/home', null)
            | """.stripMargin)
   }
+
+  protected def createNullableJsonContentTable(testTable: String): Unit = {
+    sql(s"""
+         | CREATE TABLE $testTable
+         | (
+         |   id INT,
+         |   jString STRING,
+         |   isValid BOOLEAN
+         | )
+         | USING $tableType $tableOptions
+         |""".stripMargin)
+
+    sql(s"""
+         | INSERT INTO $testTable
+         | VALUES (1, '{"account_number":1,"balance":39225,"age":32,"gender":"M"}', true),
+         |        (2, '{"f1":"abc","f2":{"f3":"a","f4":"b"}}', true),
+         |        (3, '[1,2,3,{"f1":1,"f2":[5,6]},4]', true),
+         |        (4, '[]', true),
+         |        (5, '{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', true),
+         |        (6, '[1,2,3]', true),
+         |        (7, '[1,2', false),
+         |        (8, '[invalid json]', false),
+         |        (9, '{"invalid": "json"', false),
+         |        (10, 'invalid json', false),
+         |        (11, null, false)
+         |  """.stripMargin)
+  }
 }
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala
index 26940020f..465ce7d12 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala
@@ -9,11 +9,7 @@ import org.opensearch.flint.spark.{FlintPPLSparkExtensions, FlintSparkExtensions
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
-import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.flint.config.FlintSparkConf.OPTIMIZER_RULE_ENABLED
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
 
 trait FlintPPLSuite extends FlintSparkSuite {
   override protected def sparkConf:
SparkConf = { @@ -29,11 +25,11 @@ trait FlintPPLSuite extends FlintSparkSuite { def assertSameRows(expected: Seq[Row], df: DataFrame): Unit = { QueryTest.sameRows(expected, df.collect().toSeq).foreach { results => fail(s""" - |Results do not match for query: - |${df.queryExecution} - |== Results == - |$results - """.stripMargin) + |Results do not match for query: + |${df.queryExecution} + |== Results == + |$results + """.stripMargin) } } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala new file mode 100644 index 000000000..7cc0a221d --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -0,0 +1,386 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, Not} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLJsonFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + private val validJson1 = "{\"account_number\":1,\"balance\":39225,\"age\":32,\"gender\":\"M\"}" + private val validJson2 = "{\"f1\":\"abc\",\"f2\":{\"f3\":\"a\",\"f4\":\"b\"}}" + private val validJson3 = "[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]" + private val validJson4 = "[]" + private val validJson5 = + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}" + private val validJson6 = "[1,2,3]" + private val invalidJson1 = "[1,2" + private val invalidJson2 = "[invalid json]" + private val invalidJson3 = "{\"invalid\": \"json\"" + private val invalidJson4 = "invalid json" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test json() function: valid JSON") { + Seq(validJson1, validJson2, validJson3, validJson4, validJson5).foreach { jsonStr => + val frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(jsonStr)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("test json() function: invalid JSON") { + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4).foreach { jsonStr => + val 
frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("test json() function on field") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true | eval result = json(jString) | fields result + | """.stripMargin) + assertSameRows( + Seq(validJson1, validJson2, validJson3, validJson4, validJson5, validJson6).map( + Row.apply(_)), + frame) + + val frame2 = sql(s""" + | source = $testTable + | | where isValid = false | eval result = json(jString) | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null), Row(null), Row(null), Row(null), Row(null)), frame2) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_array()") { + // test string array + var frame = sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 'a', 'string', 'array') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq("this", "is", "a", "string", "array").toArray)), frame) + + // test empty array + frame = sql(s""" + | source = $testTable | eval result = json_array() | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Array.empty)), frame) + + // test number array + frame = sql(s""" + | source = $testTable | eval result = json_array(1, 2, 0, -1, 1.1, -0.11) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(Seq(1.0, 2.0, 0.0, -1.0, 1.1, -0.11).toArray)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(0), Literal(-1), Literal(1.1), Literal(-0.11)), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + // item in json_array should all be the same type + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 1.1, -0.11, true, false) | head 1 | fields result + | """.stripMargin)) + assert(ex.getMessage().contains("should all be the same type")) + } + + test("test json_array() with json()") { + val frame = 
sql(s""" + | source = $testTable | eval result = json(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""[1.0,2.0,0.0,-1.0,1.1,-0.11]""")), frame) + } + + test("test json_array_length()") { + var frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(5)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(6)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(0)), frame) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(0)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,4]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(4)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,{"f1":1,"f2":[5,6]},4]') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(5)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('{\"key\": 1}') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2') | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row(null)), frame) + } + + test("test json_object()") { + // test value is a string + var frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', 'string_value')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":"string_value"}""")), frame) + + // test value is a number + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', 123.45)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":123.45}""")), frame) + + // test value is a boolean + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', true)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":true}""")), frame) + + frame = sql(s""" + | source = $testTable| eval result = json(json_object("a", 1, "b", 2, "c", 3)) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"a":1,"b":2,"c":3}""")), frame) + } + + test("test json_object() and json_array()") { + // test value is an empty array + var frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', array())) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":[]}""")), frame) + + // test value is an array + frame = sql(s""" + | source = $testTable| eval result = json(json_object('key', array(1, 2, 3))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"key":[1,2,3]}""")), frame) + + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false), + "result")() + var 
expectedPlan = Project( + Seq(UnresolvedAttribute("result")), + GlobalLimit( + Literal(1), + LocalLimit(Literal(1), Project(Seq(UnresolvedStar(None), jsonFunc), table)))) + comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false) + } + + test("test json_object() nested") { + val frame = sql(s""" + | source = $testTable | eval result = json(json_object('outer', json_object('inner', 123.45))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"outer":{"inner":123.45}}""")), frame) + } + + test("test json_object(), json_array() and json()") { + val frame = sql(s""" + | source = $testTable | eval result = json(json_object("array", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("""{"array":[1.0,2.0,0.0,-1.0,1.1,-0.11]}""")), frame) + } + + test("test json_valid()") { + val frame = sql(s""" + | source = $testTable + | | where json_valid(jString) | fields jString + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Seq(validJson1, validJson2, validJson3, validJson4, validJson5, validJson6) + .map(Row.apply(_)) + .toArray + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val frame2 = sql(s""" + | source = $testTable + | | where not json_valid(jString) | fields jString + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + val expectedResults2: Array[Row] = + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4, null).map(Row.apply(_)).toArray + assert(results2.sameElements(expectedResults2)) + + val logicalPlan: LogicalPlan = frame2.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = + UnresolvedFunction( + "isnotnull", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false) + val where = Filter(Not(jsonFunc), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("jString")), where) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_keys()") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true + | | eval result = json_keys(json(jString)) | fields result + | """.stripMargin) + val expectedRows = Seq( + Row(Array("account_number", "balance", "age", "gender")), + Row(Array("f1", "f2")), + Row(null), + Row(null), + Row(Array("teacher", "student")), + Row(null)) + assertSameRows(expectedRows, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "json_object_keys", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_extract()") { + val frame = sql(""" + | source = spark_catalog.default.flint_ppl_test | where id = 5 + | | eval root = json_extract(jString, '$') + | | eval teacher = json_extract(jString, '$.teacher') + | | eval students = json_extract(jString, 
'$.student') + | | eval students_* = json_extract(jString, '$.student[*]') + | | eval student_0 = json_extract(jString, '$.student[0]') + | | eval student_names = json_extract(jString, '$.student[*].name') + | | eval student_1_name = json_extract(jString, '$.student[1].name') + | | eval student_non_exist_key = json_extract(jString, '$.student[0].non_exist_key') + | | eval student_non_exist = json_extract(jString, '$.student[10]') + | | fields root, teacher, students, students_*, student_0, student_names, student_1_name, student_non_exist_key, student_non_exist + | """.stripMargin) + val expectedSeq = Seq( + Row( + """{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}""", + "Alice", + """[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]""", + """[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]""", + """{"name":"Bob","rank":1}""", + """["Bob","Charlie"]""", + "Charlie", + null, + null)) + assertSameRows(expectedSeq, frame) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 4494ee72b..ed170449a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -363,6 +363,27 @@ CAST: 'CAST'; ISEMPTY: 'ISEMPTY'; ISBLANK: 'ISBLANK'; +// JSON TEXT FUNCTIONS +JSON: 'JSON'; +JSON_OBJECT: 'JSON_OBJECT'; +JSON_ARRAY: 'JSON_ARRAY'; +JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +JSON_EXTRACT: 'JSON_EXTRACT'; +JSON_KEYS: 'JSON_KEYS'; +JSON_VALID: 'JSON_VALID'; +//JSON_APPEND: 'JSON_APPEND'; +//JSON_DELETE: 'JSON_DELETE'; +//JSON_EXTEND: 'JSON_EXTEND'; +//JSON_SET: 'JSON_SET'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_FILTER'; +//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; +//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; + +// COLLECTION FUNCTIONS +ARRAY: 'ARRAY'; + // BOOL FUNCTIONS LIKE: 'LIKE'; ISNULL: 'ISNULL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 064688983..9686b0139 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -509,6 +509,8 @@ evalFunctionName | positionFunctionName | coalesceFunctionName | cryptographicFunctionName + | jsonFunctionName + | collectionFunctionName ; functionArgs @@ -763,6 +765,7 @@ conditionFunctionBase | IFNULL | NULLIF | ISPRESENT + | JSON_VALID ; systemFunctionName @@ -791,6 +794,29 @@ textFunctionName | ISBLANK ; +jsonFunctionName + : JSON + | JSON_OBJECT + | JSON_ARRAY + | JSON_ARRAY_LENGTH + | JSON_EXTRACT + | JSON_KEYS + | JSON_VALID +// | JSON_APPEND +// | JSON_DELETE +// | JSON_EXTEND +// | JSON_SET +// | JSON_ARRAY_ALL_MATCH +// | JSON_ARRAY_ANY_MATCH +// | JSON_ARRAY_FILTER +// | JSON_ARRAY_MAP +// | JSON_ARRAY_REDUCE + ; + +collectionFunctionName + : ARRAY + ; + positionFunctionName : POSITION ; @@ -959,6 +985,7 @@ keywordsCanBeId | intervalUnit | dateTimeFunctionName | textFunctionName + | jsonFunctionName | mathematicalFunctionName | positionFunctionName | cryptographicFunctionName diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 1b41a3df8..f44fe26d8 100644 --- 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -204,6 +204,28 @@ public enum BuiltinFunctionName { TRIM(FunctionName.of("trim")), UPPER(FunctionName.of("upper")), + /** JSON Functions. */ + // If the function argument is a valid JSON, return itself, or return NULL + JSON(FunctionName.of("json")), + JSON_OBJECT(FunctionName.of("json_object")), + JSON_ARRAY(FunctionName.of("json_array")), + JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + JSON_EXTRACT(FunctionName.of("json_extract")), + JSON_KEYS(FunctionName.of("json_keys")), + JSON_VALID(FunctionName.of("json_valid")), +// JSON_DELETE(FunctionName.of("json_delete")), +// JSON_APPEND(FunctionName.of("json_append")), +// JSON_EXTEND(FunctionName.of("json_extend")), +// JSON_SET(FunctionName.of("json_set")), +// JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")), +// JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")), +// JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")), +// JSON_ARRAY_MAP(FunctionName.of("json_array_map")), +// JSON_ARRAY_REDUCE(FunctionName.of("json_array_reduce")), + + /** COLLECTION Functions **/ + ARRAY(FunctionName.of("array")), + /** NULL Test. */ IS_NULL(FunctionName.of("is null")), IS_NOT_NULL(FunctionName.of("is not null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index 485ccb522..8982fe859 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -7,13 +7,46 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal$; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; import java.util.Map; +import java.util.function.Function; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.*; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_LENGTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_EXTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_KEYS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_OBJECT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_VALID; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; +import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MODULUS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.HOUR_OF_DAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LENGTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -21,9 +54,11 @@ public interface BuiltinFunctionTranslator { /** * The name mapping between PPL builtin functions to Spark builtin functions. + * This is only used for the built-in functions between PPL and Spark with different names. + * If the built-in function names are the same in PPL and Spark, add it to {@link BuiltinFunctionName} only. */ static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING - = new ImmutableMap.Builder() + = ImmutableMap.builder() // arithmetic operators .put(ADD, "+") .put(SUBTRACT, "-") @@ -45,10 +80,6 @@ public interface BuiltinFunctionTranslator { .put(DATEDIFF, "datediff") .put(LOCALTIME, "localtimestamp") .put(SYSDATE, "now") - // Cryptographic functions - .put(MD5, "md5") - .put(SHA1, "sha1") - .put(SHA2, "sha2") // condition functions .put(IS_NULL, "isnull") .put(IS_NOT_NULL, "isnotnull") @@ -56,8 +87,60 @@ public interface BuiltinFunctionTranslator { .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + // json functions + .put(JSON_KEYS, "json_object_keys") + .put(JSON_EXTRACT, "get_json_object") .build(); + /** + * The name mapping between PPL builtin functions to Spark builtin functions. 
+ */ + static final Map, UnresolvedFunction>> PPL_TO_SPARK_FUNC_MAPPING + = ImmutableMap., UnresolvedFunction>>builder() + // json functions + .put( + JSON_ARRAY, + args -> { + return UnresolvedFunction$.MODULE$.apply("array", seq(args), false); + }) + .put( + JSON_OBJECT, + args -> { + return UnresolvedFunction$.MODULE$.apply("named_struct", seq(args), false); + }) + .put( + JSON_ARRAY_LENGTH, + args -> { + // Check if the input is an array (from json_array()) or a JSON string + if (args.get(0) instanceof UnresolvedFunction) { + // Input is a JSON array + return UnresolvedFunction$.MODULE$.apply("json_array_length", + seq(UnresolvedFunction$.MODULE$.apply("to_json", seq(args), false)), false); + } else { + // Input is a JSON string + return UnresolvedFunction$.MODULE$.apply("json_array_length", seq(args.get(0)), false); + } + }) + .put( + JSON, + args -> { + // Check if the input is a named_struct (from json_object()) or a JSON string + if (args.get(0) instanceof UnresolvedFunction) { + return UnresolvedFunction$.MODULE$.apply("to_json", seq(args.get(0)), false); + } else { + return UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false); + } + }) + .put( + JSON_VALID, + args -> { + return UnresolvedFunction$.MODULE$.apply("isnotnull", + seq(UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false)), false); + }) + .build(); + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { // TODO change it when UDF is supported @@ -65,8 +148,16 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); - String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING - .getOrDefault(builtin, builtin.getName().getFunctionName()); + String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin); + if (name != null) { + // there is a Spark builtin function mapping with the PPL builtin function + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + Function, UnresolvedFunction> alternative = PPL_TO_SPARK_FUNC_MAPPING.get(builtin); + if (alternative != null) { + return alternative.apply(args); + } + name = builtin.getName().getFunctionName(); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..f5dfc4ec8 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import 
org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanJsonFunctionsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test json()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json('[1,2,3,{"f1":1,"f2":[5,6]},4]')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "get_json_object", + Seq(Literal("""[1,2,3,{"f1":1,"f2":[5,6]},4]"""), Literal("$")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_object") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json(json_object('key', array(1, 2, 3)))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array(1, 2, 0, -1, 1.1, -0.11)"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(0), Literal(-1), Literal(1.1), Literal(-0.11)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_object() and json_array()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json(json_object('key', json_array(1, 2, 3)))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array_length(jsonString)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array_length('[1,2,3]')"""), + context) + + val table = 
UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction("json_array_length", Seq(Literal("""[1,2,3]""")), isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_array_length(json_array())") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_array_length(json_array(1,2,3))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "json_array_length", + Seq( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(3)), + isDistinct = false)), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_extract()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_extract('{"a":[{"b":1},{"b":2}]}', '$.a[1].b')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "get_json_object", + Seq(Literal("""{"a":[{"b":1},{"b":2}]}"""), Literal("""$.a[1].b""")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_keys()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_keys('{"f1":"abc","f2":{"f3":"a","f4":"b"}}')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "json_object_keys", + Seq(Literal("""{"f1":"abc","f2":{"f3":"a","f4":"b"}}""")), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("json_valid()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t a = json_valid('[1,2,3,{"f1":1,"f2":[5,6]},4]')"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val jsonFunc = + UnresolvedFunction( + "isnotnull", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(Literal("""[1,2,3,{"f1":1,"f2":[5,6]},4]"""), Literal("$")), + isDistinct = false)), + isDistinct = false) + val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } +} From 6513ead9ae057d1ffec01b4b901c884984ed227d Mon Sep 17 00:00:00 2001 From: YANGDB Date: Fri, 25 Oct 2024 10:26:40 -0700 Subject: [PATCH 3/5] PPL `fieldsummary` command (#766) * add support for FieldSummary - antlr syntax - ast expression builder - ast node builder - catalyst ast builder Signed-off-by: YANGDB * add support 
for FieldSummary - antlr syntax - ast expression builder - ast node builder - catalyst ast builder Signed-off-by: YANGDB * update sample query fix scala style format Signed-off-by: YANGDB * support spark prior to 3.5 with its extended table identifier (existing table identifier only has 2 parts) Signed-off-by: YANGDB * update union queries based summary Signed-off-by: YANGDB * update scala fmt style Signed-off-by: YANGDB * update scala fmt style Signed-off-by: YANGDB * update query with where clause predicate Signed-off-by: YANGDB * update command and remove the topvalues Signed-off-by: YANGDB * update command docs Signed-off-by: YANGDB * update with comments feedback Signed-off-by: YANGDB * update `FIELD SUMMARY` symbols to the keywordsCanBeId bag of words Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 6 + docs/ppl-lang/ppl-fieldsummary-command.md | 83 ++ .../FlintSparkPPLFieldSummaryITSuite.scala | 751 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 6 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 14 + .../sql/ast/AbstractNodeVisitor.java | 11 + .../sql/ast/expression/FieldList.java | 34 + .../opensearch/sql/ast/tree/FieldSummary.java | 57 ++ .../function/BuiltinFunctionName.java | 2 + .../sql/ppl/CatalystPlanContext.java | 9 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 8 + .../opensearch/sql/ppl/parser/AstBuilder.java | 11 +- .../sql/ppl/parser/AstExpressionBuilder.java | 19 + .../sql/ppl/utils/AggregatorTranslator.java | 30 +- .../sql/ppl/utils/DataTypeTransformer.java | 4 + .../ppl/utils/FieldSummaryTransformer.java | 253 ++++++ ...lPlanFieldSummaryTranslatorTestSuite.scala | 709 +++++++++++++++++ 17 files changed, 1993 insertions(+), 14 deletions(-) create mode 100644 docs/ppl-lang/ppl-fieldsummary-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index d161613a6..d22fc7b63 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -33,6 +33,12 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | eval b1 = b + 1 | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) - `source = table | eval b1 = lower(b) | fields - b1,c` (Field `b1` cannot be dropped caused by SPARK-49782) +**Field-Summary** +[See additional command details](ppl-fieldsummary-command.md) +- `source = t | fieldsummary includefields=status_code nulls=false` +- `source = t | fieldsummary includefields= id, status_code, request_path nulls=true` +- `source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true` + **Nested-Fields** - `source = catalog.schema.table1, catalog.schema.table2 | fields A.nested1, B.nested1` - `source = catalog.table | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield` diff --git 
a/docs/ppl-lang/ppl-fieldsummary-command.md b/docs/ppl-lang/ppl-fieldsummary-command.md
new file mode 100644
index 000000000..468c2046b
--- /dev/null
+++ b/docs/ppl-lang/ppl-fieldsummary-command.md
@@ -0,0 +1,83 @@
+## PPL `fieldsummary` command
+
+**Description**
+Use the `fieldsummary` command to:
+ - Calculate basic statistics for each field (count, distinct count, min, max, avg, stddev, mean)
+ - Determine the data type of each field
+
+**Syntax**
+
+`... | fieldsummary (includefields=<field-list>) (nulls=true/false)`
+
+* The command accepts any preceding pipes before the terminal `fieldsummary` command and takes them into account.
+* `includefields`: list of all the columns to be collected with statistics into a unified result set
+* `nulls`: optional; if true, include null values in the aggregation calculations (nulls are replaced with zero for numeric values)
+
+### Example 1:
+
+PPL query:
+
+    os> source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=true
+    +---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------+
+    | Fields        | COUNT | COUNT_DISTINCT | MIN | MAX | AVG   | MEAN  | STDDEV            | Nulls | TYPEOF |
+    |---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------|
+    | "status_code" | 2     | 2              | 301 | 403 | 352.0 | 352.0 | 72.12489168102785 | 0     | "int"  |
+    +---------------+-------+----------------+-----+-----+-------+-------+-------------------+-------+--------+
+
+### Example 2:
+
+PPL query:
+
+    os> source = t | fieldsummary includefields= id, status_code, request_path nulls=true
+    +----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------+
+    | Fields         | COUNT | COUNT_DISTINCT | MIN    | MAX   | AVG   | MEAN  | STDDEV             | Nulls | TYPEOF   |
+    |----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------|
+    | "id"           | 6     | 6              | 1      | 6     | 3.5   | 3.5   | 1.8708286933869707 | 0     | "int"    |
+    | "status_code"  | 4     | 3              | 200    | 403   | 184.0 | 184.0 | 161.16699413961905 | 2     | "int"    |
+    | "request_path" | 2     | 2              | /about | /home | 0.0   | 0.0   | 0                  | 2     | "string" |
+    +----------------+-------+----------------+--------+-------+-------+-------+--------------------+-------+----------+
+
+### Additional Info
+The actual query is translated into the following SQL-like statement:
+
+```sql
+    SELECT
+        id AS Field,
+        COUNT(id) AS COUNT,
+        COUNT(DISTINCT id) AS COUNT_DISTINCT,
+        MIN(id) AS MIN,
+        MAX(id) AS MAX,
+        AVG(id) AS AVG,
+        MEAN(id) AS MEAN,
+        STDDEV(id) AS STDDEV,
+        (COUNT(1) - COUNT(id)) AS Nulls,
+        TYPEOF(id) AS TYPEOF
+    FROM
+        t
+    GROUP BY
+        TYPEOF(id), id;
+UNION
+    SELECT
+        status_code AS Field,
+        COUNT(status_code) AS COUNT,
+        COUNT(DISTINCT status_code) AS COUNT_DISTINCT,
+        MIN(status_code) AS MIN,
+        MAX(status_code) AS MAX,
+        AVG(status_code) AS AVG,
+        MEAN(status_code) AS MEAN,
+        STDDEV(status_code) AS STDDEV,
+        (COUNT(1) - COUNT(status_code)) AS Nulls,
+        TYPEOF(status_code) AS TYPEOF
+    FROM
+        t
+    GROUP BY
+        TYPEOF(status_code), status_code;
+```
+For each such column (id, status_code) a separate statement is generated, and all the
fields will be presented togather in the result using a UNION operator + + +### Limitation: + - `topvalues` option was removed from this command due the possible performance impact of such sub-query. As an alternative one can use the `top` command directly as shown [here](ppl-top-command.md). + diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala new file mode 100644 index 000000000..5a5990001 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFieldSummaryITSuite.scala @@ -0,0 +1,751 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, Expression, Literal, NamedExpression, Not, SortOrder, Subtract} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFieldSummaryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNullableTableHttpLog(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 184.0, 184.0, 161.16699413961905, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + 
Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=false ") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 4, 3, 200, 403, 276.0, 276.0, 97.1356439899038, 2, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary 
includefields= status_code nulls=true + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val frame = sql(s""" + | source = $testTable | where status_code != 200 | fieldsummary includefields= status_code nulls=false + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("status_code", 2, 2, 301, 403, 352.0, 352.0, 72.12489168102785, 0, "int")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + 
"DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=true + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 184.0, 184.0, 161.16699413961905, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", 0.0, 0.0, 0.0, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = 
false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + 
isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val frame = sql(s""" + | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=false + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("id", 6L, 6L, "1", "6", 3.5, 3.5, 1.8708286933869707, 0, "int"), + Row("status_code", 4L, 3L, "200", "403", 276.0, 276.0, 97.1356439899038, 2, "int"), + Row("request_path", 4L, 3L, "/about", "/home", null, null, null, 2, "string")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", 
Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = + Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = + Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index ed170449a..6138a94a2 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -86,6 +86,12 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; + +// FIELDSUMMARY keywords +FIELDSUMMARY: 'FIELDSUMMARY'; +INCLUDEFIELDS: 'INCLUDEFIELDS'; +NULLS: 'NULLS'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 9686b0139..ae5f14498 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -52,6 +52,7 @@ commands | lookupCommand | 
renameCommand | fillnullCommand + | fieldsummaryCommand ; searchCommand @@ -59,6 +60,15 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? logicalExpression fromClause # searchFilterFrom ; + +fieldsummaryCommand + : FIELDSUMMARY (fieldsummaryParameter)* + ; + +fieldsummaryParameter + : INCLUDEFIELDS EQUAL fieldList # fieldsummaryIncludeFields + | NULLS EQUAL booleanLiteral # fieldsummaryNulls + ; describeCommand : DESCRIBE tableSourceClause @@ -1088,6 +1098,10 @@ keywordsCanBeId | SPARKLINE | C | DC + // FIELD SUMMARY + | FIELDSUMMARY + | INCLUDEFIELDS + | NULLS // JOIN TYPE | OUTER | INNER diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index c361ded08..5ac54127b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -16,6 +16,8 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -206,6 +208,10 @@ public T visitField(Field node, C context) { return visitChildren(node, context); } + public T visitFieldList(FieldList node, C context) { + return visitChildren(node, context); + } + public T visitQualifiedName(QualifiedName node, C context) { return visitChildren(node, context); } @@ -296,9 +302,14 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + public T visitFillNull(FillNull fillNull, C context) { return visitChildren(fillNull, context); } + + public T visitFieldSummary(FieldSummary fieldSummary, C context) { + return visitChildren(fieldSummary, context); + } public T visitScalarSubquery(ScalarSubquery node, C context) { return visitChildren(node, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java new file mode 100644 index 000000000..4f6ac5e14 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldList.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Expression node that includes a list of fields nodes. 
*/ +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@AllArgsConstructor +public class FieldList extends UnresolvedExpression { + private final List fieldList; + + @Override + public List getChild() { + return ImmutableList.copyOf(fieldList); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldList(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java new file mode 100644 index 000000000..a8072e76b --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FieldSummary.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +public class FieldSummary extends UnresolvedPlan { + private List includeFields; + private boolean includeNull; + private List collect; + private UnresolvedPlan child; + + public FieldSummary(List collect) { + this.collect = collect; + collect.forEach(exp -> { + if (exp instanceof Argument) { + this.includeNull = (boolean) ((Argument)exp).getValue().getValue(); + } + if (exp instanceof AttributeList) { + this.includeFields = ((AttributeList)exp).getAttrList(); + } + }); + } + + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFieldSummary(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f44fe26d8..9e1a9a743 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -164,6 +164,8 @@ public enum BuiltinFunctionName { /** Aggregation Function. 
*/ AVG(FunctionName.of("avg")), + MEAN(FunctionName.of("mean")), + STDDEV(FunctionName.of("stddev")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 46a016d1a..61762f616 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -154,7 +154,13 @@ public LogicalPlan withProjectedFields(List projectedField this.projectedFields.addAll(projectedFields); return getPlan(); } - + + public LogicalPlan applyBranches(List> plans) { + plans.forEach(plan -> with(plan.apply(planBranches.get(0)))); + planBranches.remove(0); + return getPlan(); + } + /** * append plan with evolving plans branches * @@ -281,4 +287,5 @@ public static Optional findRelation(LogicalPlan plan) { // Return null if no UnresolvedRelation is found return Optional.empty(); } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 902fc72e3..76a7a0c79 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -65,6 +65,7 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; @@ -85,6 +86,7 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -380,6 +382,12 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) { node.getSize(), DataTypes.IntegerType), p)); } + @Override + public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) { + fieldSummary.getChild().get(0).accept(this, context); + return FieldSummaryTransformer.translate(fieldSummary, context); + } + @Override public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) { fillNull.getChild().get(0).accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 1c0fe919f..26a8e2278 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -415,8 +416,14 @@ public 
UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) groupListBuilder.build()); return aggregation; } - - /** Rare command. */ + + /** Fieldsummary command. */ + @Override + public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryCommandContext ctx) { + return new FieldSummary(ctx.fieldsummaryParameter().stream().map(arg -> expressionBuilder.visit(arg)).collect(Collectors.toList())); + } + + /** Rare command. */ @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 3b98edd77..ea51ca7a1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -16,11 +16,13 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; @@ -39,6 +41,7 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -50,6 +53,8 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; @@ -179,6 +184,20 @@ public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedExpression visitFieldsummaryIncludeFields(OpenSearchPPLParser.FieldsummaryIncludeFieldsContext ctx) { + List list = ctx.fieldList().fieldExpression().stream() + .map(this::visitFieldExpression) + .collect(Collectors.toList()); + return new AttributeList(list); + } + + @Override + public UnresolvedExpression visitFieldsummaryNulls(OpenSearchPPLParser.FieldsummaryNullsContext ctx) { + return new Argument("NULLS",(Literal)visitBooleanLiteral(ctx.booleanLiteral())); + } + + /** * Aggregation function. 
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index 3c367a948..a01b38a80 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -12,10 +12,12 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; +import java.util.Optional; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -26,31 +28,37 @@ * @return */ public interface AggregatorTranslator { - + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + boolean distinct = aggregateFunction.getDistinct(); // Additional aggregation function operators will be added here - switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get(); + switch (functionName) { case MAX: - return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false); case MIN: - return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false); + case MEAN: + return new UnresolvedFunction(seq("MEAN"), seq(arg), distinct, empty(),false); case AVG: - return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false); case COUNT: - return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false); case SUM: - return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false); + case STDDEV: + return new UnresolvedFunction(seq("STDDEV"), seq(arg), distinct, empty(),false); case STDDEV_POP: - return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false); case STDDEV_SAMP: - return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false); case PERCENTILE: - return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new 
Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: - return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 4345b0897..62eef90ed 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -20,7 +20,10 @@ import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; +import java.util.Arrays; import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -41,6 +44,7 @@ public interface DataTypeTransformer { static Seq seq(T... elements) { return seq(List.of(elements)); } + static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java new file mode 100644 index 000000000..dd8f01874 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/FieldSummaryTransformer.java @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.AliasIdentifier; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.expressions.Subtract; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LocalLimit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.Sort; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.Field; 
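+// Illustrative sketch of the output plan (based on the accompanying tests): for a query such as
+//   source = t | fieldsummary includefields= status_code nulls=true
+// this translator builds one aggregation branch per included field, grouped by TYPEOF(field),
+// roughly of the shape
+//   Aggregate [TYPEOF(status_code) AS TYPEOF],
+//             ['status_code' AS Field,
+//              COUNT(status_code) AS COUNT, COUNT(DISTINCT status_code) AS DISTINCT,
+//              MIN(status_code) AS MIN, MAX(status_code) AS MAX,
+//              AVG(COALESCE(status_code, 0)) AS AVG, MEAN(COALESCE(status_code, 0)) AS MEAN,
+//              STDDEV(COALESCE(status_code, 0)) AS STDDEV,
+//              COUNT(1) - COUNT(status_code) AS Nulls, TYPEOF(status_code) AS TYPEOF]
+// With nulls=false the COALESCE wrapping is omitted; when several fields are included, the
+// per-field branches are combined into a Union via CatalystPlanContext.applyBranches.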
+import org.opensearch.sql.ast.tree.FieldSummary; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystPlanContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MEAN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface FieldSummaryTransformer { + + String TOP_VALUES = "TopValues"; + String NULLS = "Nulls"; + String FIELD = "Field"; + + /** + * translate the command into the aggregate statement group by the column name + */ + static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext context) { + List> aggBranches = fieldSummary.getIncludeFields().stream() + .filter(field -> field instanceof org.opensearch.sql.ast.expression.Field ) + .map(field -> { + Literal fieldNameLiteral = Literal.create(((org.opensearch.sql.ast.expression.Field)field).getField().toString(), StringType); + UnresolvedAttribute fieldLiteral = new UnresolvedAttribute(seq(((org.opensearch.sql.ast.expression.Field)field).getField().getParts())); + context.withProjectedFields(Collections.singletonList(field)); + + // Alias for the field name as Field + Alias fieldNameAlias = Alias$.MODULE$.apply(fieldNameLiteral, + FIELD, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(field) as Count + UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(count, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the count(DISTINCT field) as CountDistinct + UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false); + Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct, + "DISTINCT", + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MAX(field) as MAX + UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false); + Alias maxAlias = Alias$.MODULE$.apply(max, + MAX.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the MIN(field) as Min + UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false); + Alias minAlias = Alias$.MODULE$.apply(min, + MIN.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Alias for the AVG(field) as Avg + Alias avgAlias = getAggMethodAlias(AVG, fieldSummary, fieldLiteral); + + //Alias for the MEAN(field) as Mean + Alias meanAlias = getAggMethodAlias(MEAN, fieldSummary, fieldLiteral); + + //Alias for the STDDEV(field) as Stddev + Alias stddevAlias = 
getAggMethodAlias(STDDEV, fieldSummary, fieldLiteral); + + // Alias COUNT(*) - COUNT(column2) AS Nulls + UnresolvedFunction countStar = new UnresolvedFunction(seq(COUNT.name()), seq(Literal.create(1, IntegerType)), false, empty(), false); + Alias nonNullAlias = Alias$.MODULE$.apply( + new Subtract(countStar, count), + NULLS, + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + + //Alias for the typeOf(field) as Type + UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false); + Alias typeOfAlias = Alias$.MODULE$.apply(typeOf, + TYPEOF.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + //Aggregation + return (Function) p -> + new Aggregate(seq(typeOfAlias), seq(fieldNameAlias, countAlias, distinctCountAlias, minAlias, maxAlias, avgAlias, meanAlias, stddevAlias, nonNullAlias, typeOfAlias), p); + }).collect(Collectors.toList()); + + return context.applyBranches(aggBranches); + } + + /** + * Alias for aggregate function (if isIncludeNull use COALESCE to replace nulls with zeros) + */ + private static Alias getAggMethodAlias(BuiltinFunctionName method, FieldSummary fieldSummary, UnresolvedAttribute fieldLiteral) { + UnresolvedFunction avg = new UnresolvedFunction(seq(method.name()), seq(fieldLiteral), false, empty(), false); + Alias avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + + if (fieldSummary.isIncludeNull()) { + UnresolvedFunction coalesceExpr = new UnresolvedFunction( + seq("COALESCE"), + seq(fieldLiteral, Literal.create(0, DataTypes.IntegerType)), + false, + empty(), + false + ); + avg = new UnresolvedFunction(seq(method.name()), seq(coalesceExpr), false, empty(), false); + avgAlias = Alias$.MODULE$.apply(avg, + method.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + } + return avgAlias; + } + + /** + * top values sub-query + */ + private static Alias topValuesSubQueryAlias(FieldSummary fieldSummary, CatalystPlanContext context, UnresolvedAttribute fieldLiteral, UnresolvedFunction count) { + int topValues = 5;// this value should come from the FieldSummary definition + CreateNamedStruct structExpr = new CreateNamedStruct(seq( + fieldLiteral, + count + )); + // Alias COLLECT_LIST(STRUCT(field, COUNT(field))) AS top_values + UnresolvedFunction collectList = new UnresolvedFunction( + seq("COLLECT_LIST"), + seq(structExpr), + false, + empty(), + !fieldSummary.isIncludeNull() + ); + Alias topValuesAlias = Alias$.MODULE$.apply( + collectList, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + Project subQueryProject = new Project(seq(topValuesAlias), buildTopValueSubQuery(topValues, fieldLiteral, context)); + ScalarSubquery scalarSubquery = ScalarSubquery$.MODULE$.apply( + subQueryProject, + seq(new ArrayList()), + NamedExpression.newExprId(), + seq(new ArrayList()), + empty(), + empty()); + + return Alias$.MODULE$.apply( + scalarSubquery, + TOP_VALUES, + NamedExpression.newExprId(), + seq(), + empty(), + seq() + ); + } + + /** + * inner top values query + * ----------------------------------------------------- + * : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, count_status)), None)] + * : : +- 'GlobalLimit 5 + * : : +- 'LocalLimit 5 + * : : +- 'Sort ['count_status DESC NULLS LAST], true + * : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27] + * : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + private 
static LogicalPlan buildTopValueSubQuery(int topValues, UnresolvedAttribute fieldLiteral, CatalystPlanContext context) { + //Alias for the count(field) as Count + UnresolvedFunction countFunc = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false); + Alias countAlias = Alias$.MODULE$.apply(countFunc, + COUNT.name(), + NamedExpression.newExprId(), + seq(), + empty(), + seq()); + Aggregate aggregate = new Aggregate(seq(fieldLiteral), seq(countAlias), context.getPlan()); + Project project = new Project(seq(fieldLiteral, countAlias), aggregate); + SortOrder sortOrder = new SortOrder(countAlias, Descending$.MODULE$, Ascending$.MODULE$.defaultNullOrdering(), seq()); + Sort sort = new Sort(seq(sortOrder), true, project); + GlobalLimit limit = new GlobalLimit(Literal.create(topValues, IntegerType), new LocalLimit(Literal.create(topValues, IntegerType), sort)); + return new SubqueryAlias(new AliasIdentifier(TOP_VALUES + "_subquery"), limit); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala new file mode 100644 index 000000000..c14e1f6cf --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFieldSummaryTranslatorTestSuite.scala @@ -0,0 +1,709 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, NamedExpression, Not, Subtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project, Union} + +class PPLLogicalPlanFieldSummaryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test fieldsummary with single field includefields(status_code) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + 
UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test("test fieldsummary with single field includefields(status_code) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = t | fieldsummary includefields= status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + table) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=true with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= 
status_code nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(status_code) & nulls=false with a where filter ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | where status_code != 200 | fieldsummary includefields= status_code nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + val aggregateExpressions: Seq[NamedExpression] = Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + 
UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()) + + val filterCondition = Not(EqualTo(UnresolvedAttribute("status_code"), Literal(200))) + val aggregatePlan = Aggregate( + groupingExpressions = Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + aggregateExpressions, + Filter(filterCondition, table)) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=true"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("id"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = 
false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("status_code"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction( + "AVG", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq( + UnresolvedFunction( + "COALESCE", + Seq(UnresolvedAttribute("request_path"), Literal(0)), + isDistinct = false)), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=false") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = t | fieldsummary includefields= id, status_code, request_path nulls=false"), + context) + + // Define the table + val table = UnresolvedRelation(Seq("t")) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateIdPlan = Aggregate( + Seq( + Alias( + 
UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("id"), "Field")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("id")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("id")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction("STDDEV", Seq(UnresolvedAttribute("id")), isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("id")), isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("id")), isDistinct = false), + "TYPEOF")()), + table) + val idProj = Project(seq(UnresolvedStar(None)), aggregateIdPlan) + + // Aggregate with functions applied to status_code + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregateStatusCodePlan = Aggregate( + Seq(Alias( + UnresolvedFunction("TYPEOF", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("status_code"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("status_code")), isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction("MEAN", Seq(UnresolvedAttribute("status_code")), isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "TYPEOF")()), + table) + val statusCodeProj = Project(seq(UnresolvedStar(None)), aggregateStatusCodePlan) + + // Define the aggregate plan with alias for TYPEOF in the aggregation + val aggregatePlan = Aggregate( + Seq( + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + Seq( + Alias(Literal("request_path"), "Field")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "COUNT")(), + Alias( + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = true), + "DISTINCT")(), + Alias( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "MIN")(), + Alias( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("request_path")), 
isDistinct = false), + "MAX")(), + Alias( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("request_path")), isDistinct = false), + "AVG")(), + Alias( + UnresolvedFunction( + "MEAN", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "MEAN")(), + Alias( + UnresolvedFunction( + "STDDEV", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "STDDEV")(), + Alias( + Subtract( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + UnresolvedFunction( + "COUNT", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false)), + "Nulls")(), + Alias( + UnresolvedFunction( + "TYPEOF", + Seq(UnresolvedAttribute("request_path")), + isDistinct = false), + "TYPEOF")()), + table) + val requestPathProj = Project(seq(UnresolvedStar(None)), aggregatePlan) + + val expectedPlan = Union(seq(idProj, statusCodeProj, requestPathProj), true, true) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } +} From 2a647d43264cdcbb65f95b22c8373dee50497937 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Fri, 25 Oct 2024 10:26:58 -0700 Subject: [PATCH 4/5] Add read/write bytes metrics (#803) * Add read/write bytes metrics Signed-off-by: Tomoyuki Morita * Add unit test Signed-off-by: Tomoyuki Morita * Address comments Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita --- build.sbt | 1 + .../flint/core/metrics/HistoricGauge.java | 63 +++++++++++++++ .../flint/core/metrics/MetricConstants.java | 20 +++++ .../flint/core/metrics/MetricsUtil.java | 27 +++++++ .../DimensionedCloudWatchReporter.java | 31 +++++++- .../metrics/ReadWriteBytesSparkListener.scala | 58 ++++++++++++++ .../flint/core/metrics/HistoricGaugeTest.java | 79 +++++++++++++++++++ .../flint/core/metrics/MetricsUtilTest.java | 30 ++++++- .../spark/refresh/AutoIndexRefresh.scala | 21 ++--- .../org/apache/spark/sql/FlintREPL.scala | 18 +++-- .../org/apache/spark/sql/JobOperator.scala | 12 ++- 11 files changed, 337 insertions(+), 23 deletions(-) create mode 100644 flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala create mode 100644 flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java diff --git a/build.sbt b/build.sbt index f7653c50c..30858e8d6 100644 --- a/build.sbt +++ b/build.sbt @@ -89,6 +89,7 @@ lazy val flintCore = (project in file("flint-core")) "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), "software.amazon.awssdk" % "auth-crt" % "2.28.10", + "org.projectlombok" % "lombok" % "1.18.30" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java new file mode 100644 index 000000000..181bf8575 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import com.codahale.metrics.Gauge; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Value; 
+ +/** + * Gauge which stores historic data points with timestamps. + * This is used for emitting separate data points per request, instead of single aggregated metrics. + */ +public class HistoricGauge implements Gauge { + @AllArgsConstructor + @Value + public static class DataPoint { + Long value; + long timestamp; + } + + private final List dataPoints = Collections.synchronizedList(new LinkedList<>()); + + /** + * This method will just return first value. + * @return first value + */ + @Override + public Long getValue() { + if (!dataPoints.isEmpty()) { + return dataPoints.get(0).value; + } else { + return null; + } + } + + /** + * Add new data point. Current time stamp will be attached to the data point. + * @param value metric value + */ + public void addDataPoint(Long value) { + dataPoints.add(new DataPoint(value, System.currentTimeMillis())); + } + + /** + * Return copy of dataPoints and remove them from internal list + * @return copy of the data points + */ + public List pollDataPoints() { + int size = dataPoints.size(); + List result = new ArrayList<>(dataPoints.subList(0, size)); + if (size > 0) { + dataPoints.subList(0, size).clear(); + } + return result; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 3a72c1d5a..427fab9fe 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -117,6 +117,26 @@ public final class MetricConstants { */ public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + /** + * Metric for tracking the total bytes read from input + */ + public static final String INPUT_TOTAL_BYTES_READ = "input.totalBytesRead.count"; + + /** + * Metric for tracking the total records read from input + */ + public static final String INPUT_TOTAL_RECORDS_READ = "input.totalRecordsRead.count"; + + /** + * Metric for tracking the total bytes written to output + */ + public static final String OUTPUT_TOTAL_BYTES_WRITTEN = "output.totalBytesWritten.count"; + + /** + * Metric for tracking the total records written to output + */ + public static final String OUTPUT_TOTAL_RECORDS_WRITTEN = "output.totalRecordsWritten.count"; + private MetricConstants() { // Private constructor to prevent instantiation } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index ab1207ccc..511c18664 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -75,6 +75,15 @@ public static void decrementCounter(String metricName, boolean isIndexMetric) { } } + public static void setCounter(String metricName, boolean isIndexMetric, long n) { + Counter counter = getOrCreateCounter(metricName, isIndexMetric); + if (counter != null) { + counter.dec(counter.getCount()); + counter.inc(n); + LOG.info("counter: " + counter.getCount()); + } + } + /** * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. * @@ -111,6 +120,24 @@ public static Timer getTimer(String metricName, boolean isIndexMetric) { return getOrCreateTimer(metricName, isIndexMetric); } + /** + * Registers a HistoricGauge metric with the provided name and value. 
+ * + * @param metricName The name of the HistoricGauge metric to register. + * @param value The value to be stored + */ + public static void addHistoricGauge(String metricName, final long value) { + HistoricGauge historicGauge = getOrCreateHistoricGauge(metricName); + if (historicGauge != null) { + historicGauge.addDataPoint(value); + } + } + + private static HistoricGauge getOrCreateHistoricGauge(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(false); + return metricRegistry != null ? metricRegistry.gauge(metricName, HistoricGauge::new) : null; + } + /** * Registers a gauge metric with the provided name and value. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java index a5ea190c5..9104e1b34 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -47,6 +47,7 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups; +import org.opensearch.flint.core.metrics.HistoricGauge; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -145,7 +146,11 @@ public void report(final SortedMap gauges, gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size()); for (final Map.Entry gaugeEntry : gauges.entrySet()) { - processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + if (gaugeEntry.getValue() instanceof HistoricGauge) { + processHistoricGauge(gaugeEntry.getKey(), (HistoricGauge) gaugeEntry.getValue(), metricData); + } else { + processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + } } for (final Map.Entry counterEntry : counters.entrySet()) { @@ -227,6 +232,13 @@ private void processGauge(final String metricName, final Gauge gauge, final List } } + private void processHistoricGauge(final String metricName, final HistoricGauge gauge, final List metricData) { + for (HistoricGauge.DataPoint dataPoint: gauge.pollDataPoints()) { + stageMetricDatum(true, metricName, dataPoint.getValue().doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData, + dataPoint.getTimestamp()); + } + } + private void processCounter(final String metricName, final Counting counter, final List metricData) { long currentCount = counter.getCount(); Long lastCount = lastPolledCounts.get(counter); @@ -333,12 +345,25 @@ private void processHistogram(final String metricName, final Histogram histogram *

* If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted */ + private void stageMetricDatum(final boolean metricConfigured, + final String metricName, + final double metricValue, + final StandardUnit standardUnit, + final String dimensionValue, + final List metricData + ) { + stageMetricDatum(metricConfigured, metricName, metricValue, standardUnit, + dimensionValue, metricData, builder.clock.getTime()); + } + private void stageMetricDatum(final boolean metricConfigured, final String metricName, final double metricValue, final StandardUnit standardUnit, final String dimensionValue, - final List metricData) { + final List metricData, + final Long timestamp + ) { // Only submit metrics that show some data, so let's save some money if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { final DimensionedName dimensionedName = DimensionedName.decode(metricName); @@ -351,7 +376,7 @@ private void stageMetricDatum(final boolean metricConfigured, MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions); for (Set dimensionSet : metricInfo.getDimensionSets()) { MetricDatum datum = new MetricDatum() - .withTimestamp(new Date(builder.clock.getTime())) + .withTimestamp(new Date(timestamp)) .withValue(cleanMetricValue(metricValue)) .withMetricName(metricInfo.getMetricName()) .withDimensions(dimensionSet) diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala new file mode 100644 index 000000000..bfafd3eb3 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.SparkSession + +/** + * Collect and emit bytesRead/Written and recordsRead/Written metrics + */ +class ReadWriteBytesSparkListener extends SparkListener with Logging { + var bytesRead: Long = 0 + var recordsRead: Long = 0 + var bytesWritten: Long = 0 + var recordsWritten: Long = 0 + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val inputMetrics = taskEnd.taskMetrics.inputMetrics + val outputMetrics = taskEnd.taskMetrics.outputMetrics + val ids = s"(${taskEnd.taskInfo.taskId}, ${taskEnd.taskInfo.partitionId})" + logInfo( + s"${ids} Input: bytesRead=${inputMetrics.bytesRead}, recordsRead=${inputMetrics.recordsRead}") + logInfo( + s"${ids} Output: bytesWritten=${outputMetrics.bytesWritten}, recordsWritten=${outputMetrics.recordsWritten}") + + bytesRead += inputMetrics.bytesRead + recordsRead += inputMetrics.recordsRead + bytesWritten += outputMetrics.bytesWritten + recordsWritten += outputMetrics.recordsWritten + } + + def emitMetrics(): Unit = { + logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}") + logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}") + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead) + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten) 
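// Editorial note (not part of the patch): each addHistoricGauge call above records a timestamped
// data point on a HistoricGauge rather than overwriting a single gauge value, so the CloudWatch
// reporter (DimensionedCloudWatchReporter.processHistoricGauge) can emit one datum per query or
// refresh run instead of only the most recent aggregate.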
+ } +} + +object ReadWriteBytesSparkListener { + def withMetrics[T](spark: SparkSession, lambda: () => T): T = { + val listener = new ReadWriteBytesSparkListener() + spark.sparkContext.addSparkListener(listener) + + val result = lambda() + + spark.sparkContext.removeSparkListener(listener) + listener.emitMetrics() + + result + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java new file mode 100644 index 000000000..f3d842af2 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import org.junit.Test; +import static org.junit.Assert.*; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; + +import java.util.List; + +public class HistoricGaugeTest { + + @Test + public void testGetValue_EmptyGauge_ShouldReturnNull() { + HistoricGauge gauge= new HistoricGauge(); + assertNull(gauge.getValue()); + } + + @Test + public void testGetValue_WithSingleDataPoint_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + assertEquals(value, gauge.getValue()); + } + + @Test + public void testGetValue_WithMultipleDataPoints_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long firstValue = 100L; + Long secondValue = 200L; + gauge.addDataPoint(firstValue); + gauge.addDataPoint(secondValue); + + assertEquals(firstValue, gauge.getValue()); + } + + @Test + public void testPollDataPoints_WithMultipleDataPoints_ShouldReturnAndClearDataPoints() { + HistoricGauge gauge= new HistoricGauge(); + gauge.addDataPoint(100L); + gauge.addDataPoint(200L); + gauge.addDataPoint(300L); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(3, dataPoints.size()); + assertEquals(Long.valueOf(100L), dataPoints.get(0).getValue()); + assertEquals(Long.valueOf(200L), dataPoints.get(1).getValue()); + assertEquals(Long.valueOf(300L), dataPoints.get(2).getValue()); + + assertTrue(gauge.pollDataPoints().isEmpty()); + } + + @Test + public void testAddDataPoint_ShouldAddDataPointWithCorrectValueAndTimestamp() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(1, dataPoints.size()); + assertEquals(value, dataPoints.get(0).getValue()); + assertTrue(dataPoints.get(0).getTimestamp() > 0); + } + + @Test + public void testPollDataPoints_EmptyGauge_ShouldReturnEmptyList() { + HistoricGauge gauge= new HistoricGauge(); + List dataPoints = gauge.pollDataPoints(); + + assertTrue(dataPoints.isEmpty()); + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index b5470b6be..70b51ed63 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -4,7 +4,7 @@ import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; import java.time.Duration; -import java.time.temporal.TemporalUnit; +import java.util.List; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import 
org.apache.spark.metrics.source.FlintIndexMetricSource; @@ -16,6 +16,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -199,4 +200,31 @@ public void testDefaultBehavior() { Assertions.assertNotNull(flintMetricSource.metricRegistry().getGauges().get(testGaugeMetric)); } } + + @Test + public void testAddHistoricGauge() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + String sourceName = FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); + + long value1 = 100L; + long value2 = 200L; + String gaugeName = "test.gauge"; + MetricsUtil.addHistoricGauge(gaugeName, value1); + MetricsUtil.addHistoricGauge(gaugeName, value2); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(metricSource, times(2)).metricRegistry(); + + HistoricGauge gauge = (HistoricGauge)metricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + List dataPoints = gauge.pollDataPoints(); + Assertions.assertEquals(value1, dataPoints.get(0).getValue()); + Assertions.assertEquals(value2, dataPoints.get(1).getValue()); + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala index d343fd999..bedeeba54 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.refresh import java.util.Collections +import org.opensearch.flint.core.metrics.ReadWriteBytesSparkListener import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper} import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode} @@ -67,15 +68,17 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) // Flint index has specialized logic and capability for incremental refresh case refresh: StreamingRefresh => logInfo("Start refreshing index in streaming style") - val job = - refresh - .buildStream(spark) - .writeStream - .queryName(indexName) - .format(FLINT_DATASOURCE) - .options(flintSparkConf.properties) - .addSinkOptions(options, flintSparkConf) - .start(indexName) + val job = ReadWriteBytesSparkListener.withMetrics( + spark, + () => + refresh + .buildStream(spark) + .writeStream + .queryName(indexName) + .format(FLINT_DATASOURCE) + .options(flintSparkConf.properties) + .addSinkOptions(options, flintSparkConf) + .start(indexName)) Some(job.id.toString) // Otherwise, fall back to foreachBatch + batch refresh diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index cdeebe663..0978e6898 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -17,7 +17,7 @@ import com.codahale.metrics.Timer import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.{MetricConstants, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.apache.spark.SparkConf @@ -525,12 +525,16 @@ object FlintREPL extends Logging with FlintJobExecutor { val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - processStatementOnVerification( - statementExecutionManager, - queryResultWriter, - flintStatement, - state, - context) + ReadWriteBytesSparkListener.withMetrics( + spark, + () => { + processStatementOnVerification( + statementExecutionManager, + queryResultWriter, + flintStatement, + state, + context) + }) verificationResult = returnedVerificationResult finalizeCommand( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 01d8cb05c..6cdbdb16d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -14,7 +14,7 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.common.scheduler.model.LangType -import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark @@ -70,6 +70,9 @@ case class JobOperator( val statementExecutionManager = instantiateStatementExecutionManager(commandContext, resultIndex, osClient) + val readWriteBytesSparkListener = new ReadWriteBytesSparkListener() + sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener) + val statement = new FlintStatement( "running", @@ -137,6 +140,8 @@ case class JobOperator( startTime)) } finally { emitQueryExecutionTimeMetric(startTime) + readWriteBytesSparkListener.emitMetrics() + sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) @@ -202,8 +207,9 @@ case class JobOperator( private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { MetricsUtil - .getTimer(MetricConstants.QUERY_EXECUTION_TIME_METRIC, false) - .update(System.currentTimeMillis() - startTime, TimeUnit.MILLISECONDS); + .addHistoricGauge( + MetricConstants.QUERY_EXECUTION_TIME_METRIC, + System.currentTimeMillis() - startTime) } def stop(): Unit = { From 7bc09278d8a9d7df05ff2e086e7f430a4a0da33c Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Sat, 26 Oct 2024 01:27:10 +0800 Subject: [PATCH 5/5] Support `Eventstats` in PPL (#800) * Support Eventstats in PPL Signed-off-by: Lantao Jin * add doc Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin Co-authored-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 28 ++ docs/ppl-lang/README.md | 2 + docs/ppl-lang/ppl-eventstats-command.md | 
327 +++++++++++++++ .../ppl/FlintSparkPPLEventstatsITSuite.scala | 379 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 +- .../sql/ast/AbstractNodeVisitor.java | 4 + .../org/opensearch/sql/ast/tree/Window.java | 45 +++ .../sql/ppl/CatalystQueryPlanVisitor.java | 27 ++ .../opensearch/sql/ppl/parser/AstBuilder.java | 26 +- .../sql/ppl/utils/WindowSpecTransformer.java | 19 + ...calPlanEventstatsTranslatorTestSuite.scala | 256 ++++++++++++ 12 files changed, 1107 insertions(+), 9 deletions(-) create mode 100644 docs/ppl-lang/ppl-eventstats-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index d22fc7b63..8e796c6fb 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -173,6 +173,34 @@ source = table | where ispresent(a) | - `source = table | stats avg(age) as avg_state_age by country, state | stats avg(avg_state_age) as avg_country_age by country` - `source = table | stats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | stats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | stats avg(avg_state_age) as avg_adult_country_age by country` +#### **Event Aggregations** +[See additional command details](ppl-eventstats-command.md) + +- `source = table | eventstats avg(a) ` +- `source = table | where a < 50 | eventstats avg(c) ` +- `source = table | eventstats max(c) by b` +- `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats stddev_samp(c)` +- `source = table | eventstats stddev_pop(c)` +- `source = table | eventstats percentile(c, 90)` +- `source = table | eventstats percentile_approx(c, 99)` + +**Limitation: distinct aggregation cannot be used in `eventstats`:** +- `source = table | eventstats distinct_count(c)` (throws an exception) + +**Aggregations With Span** +- `source = table | eventstats count(a) by span(a, 10) as a_span` +- `source = table | eventstats sum(age) by span(age, 5) as age_span | head 2` +- `source = table | eventstats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` + +**Aggregations With TimeWindow Span (tumble windowing function)** +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` + +**Aggregations Group by Multiple Times** +- `source = table | eventstats avg(age) as avg_state_age by country, state | eventstats avg(avg_state_age) as avg_country_age by country` +- `source = table | eventstats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | eventstats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | eventstats avg(avg_state_age) as avg_adult_country_age by country` + #### **Dedup** [See additional command details](ppl-dedup-command.md) diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 4fa9d10cc..9cb5f118e 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ 
-50,6 +50,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`stats command`](ppl-stats-command.md) + + - [`eventstats command`](ppl-eventstats-command.md) + - [`where command`](ppl-where-command.md) - [`head command`](ppl-head-command.md) diff --git a/docs/ppl-lang/ppl-eventstats-command.md b/docs/ppl-lang/ppl-eventstats-command.md new file mode 100644 index 000000000..9a65d5052 --- /dev/null +++ b/docs/ppl-lang/ppl-eventstats-command.md @@ -0,0 +1,327 @@ +## PPL `eventstats` command + +### Description +The `eventstats` command enriches your event data with calculated summary statistics. It operates by analyzing specified fields within your events, computing various statistical measures, and then appending these results as new fields to each original event. + +Key aspects of `eventstats`: + +1. It performs calculations across the entire result set or within defined groups. +2. The original events remain intact, with new fields added to contain the statistical results. +3. The command is particularly useful for comparative analysis, identifying outliers, or providing additional context to individual events. + +### Difference between [`stats`](ppl-stats-command.md) and `eventstats` +The `stats` and `eventstats` commands are both used for calculating statistics, but they have some key differences in how they operate and what they produce: + +- Output Format: + - `stats`: Produces a summary table with only the calculated statistics. + - `eventstats`: Adds the calculated statistics as new fields to the existing events, preserving the original data. +- Event Retention: + - `stats`: Reduces the result set to only the statistical summary, discarding individual events. + - `eventstats`: Retains all original events and adds new fields with the calculated statistics. +- Use Cases: + - `stats`: Best for creating summary reports or dashboards. Often used as a final command to summarize results. + - `eventstats`: Useful when you need to enrich events with statistical context for further analysis or filtering. Can be used mid-search to add statistics that can be used in subsequent commands. + +### Syntax +`eventstats ... [by-clause]` + +### **aggregation:** +mandatory. An aggregation function. The argument of the aggregation must be a field. + +**by-clause**: optional. + +#### Syntax: +`by [span-expression,] [field,]...` + +**Description:** + +The by clause can consist of fields and expressions such as scalar functions and aggregation functions. +In addition, the span clause can be used to split a specific field into buckets of the same interval; eventstats then performs the aggregation by these span buckets. + +**Default**: + +If no `` is specified, the eventstats command aggregates over the entire result set. + +### **`span-expression`**: +optional, at most one. + +#### Syntax: +`span(field_expr, interval_expr)` + +**Description:** + +The unit of the interval expression is the natural unit by default. +If the field is a date and time type field, and the interval is in date/time units, you will need to specify the unit in the interval expression. + +For example, to split the field ``age`` into buckets of 10 years, use ``span(age, 10)``. As another example of a time span, to split a ``timestamp`` field into hourly intervals, use ``span(timestamp, 1h)``. 
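(Editorial illustration, not part of the patch or the PPL grammar: the snippet below is a hypothetical Spark sketch of what a time-based span such as ``span(timestamp, 1h)`` conceptually corresponds to, namely bucketing rows into fixed tumbling windows. The sample data and column names are invented for the example. The span interval units supported by PPL are listed right after it.)

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, window}

object SpanBucketSketch {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder().appName("span-bucket-sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    // Hypothetical events; the timestamp strings are cast to a proper timestamp column.
    val events = Seq(
      ("2024-10-25 10:05:00", 200),
      ("2024-10-25 10:40:00", 404),
      ("2024-10-25 11:10:00", 200))
      .toDF("timestamp", "status_code")
      .withColumn("timestamp", $"timestamp".cast("timestamp"))

    // Roughly analogous to: ... | stats avg(status_code) by span(timestamp, 1h)
    // eventstats would instead keep every row and append the bucket-level value as a new column.
    events
      .groupBy(window($"timestamp", "1 hour"))
      .agg(avg($"status_code").as("avg_status_code"))
      .show(truncate = false)

    spark.stop()
  }
}
```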
+ +* Available time unit: +``` ++----------------------------+ +| Span Interval Units | ++============================+ +| millisecond (ms) | ++----------------------------+ +| second (s) | ++----------------------------+ +| minute (m, case sensitive) | ++----------------------------+ +| hour (h) | ++----------------------------+ +| day (d) | ++----------------------------+ +| week (w) | ++----------------------------+ +| month (M, case sensitive) | ++----------------------------+ +| quarter (q) | ++----------------------------+ +| year (y) | ++----------------------------+ +``` + +### Aggregation Functions + +#### _COUNT_ + +**Description** + +Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example: + + os> source=accounts | eventstats count(); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | count() | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 4 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 4 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 4 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 4 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+---------+ + + +#### _SUM_ + +**Description** + +`SUM(expr)`. Returns the sum of expr. + +Example: + + os> source=accounts | eventstats sum(age) by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | sum(age) by gender | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 101 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 101 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 101 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + + +#### _AVG_ + +**Description** + +`AVG(expr)`. Returns the average value of expr. 
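(Editorial aside, a minimal hypothetical sketch rather than the Flint implementation: the row-preserving behaviour of `eventstats avg(...) by ...` can be mimicked on a plain Spark DataFrame with a window aggregate. The session setup and sample rows below are assumptions made for the example; the actual PPL example follows.)

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.avg

object EventstatsAvgSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("eventstats-avg-sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    // Hypothetical accounts data mirroring the documentation example below.
    val accounts = Seq((1, "M", 32), (6, "M", 36), (13, "F", 28), (18, "M", 33))
      .toDF("account_number", "gender", "age")

    // Roughly analogous to: source = accounts | eventstats avg(age) by gender
    // Every input row is kept; the per-gender average is appended as a new column.
    val enriched = accounts.withColumn(
      "avg(age) by gender",
      avg($"age").over(Window.partitionBy($"gender")))
    enriched.show()

    spark.stop()
  }
}
```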
+ +Example: + + os> source=accounts | eventstats avg(age) by gender; + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | avg(age) by gender | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 33.67 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 33.67 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28.00 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 33.67 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------+ + + +#### MAX + +**Description** + +`MAX(expr)` Returns the maximum value of expr. + +Example: + + os> source=accounts | eventstats max(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | max(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 36 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 36 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 36 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 36 | ++----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + + +#### MIN + +**Description** + +`MIN(expr)` Returns the minimum value of expr. 
+ +Example: + + os> source=accounts | eventstats min(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | min(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 28 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 28 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 28 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+ + + +#### STDDEV_SAMP + +**Description** + +`STDDEV_SAMP(expr)` Return the sample standard deviation of expr. + +Example: + + os> source=accounts | eventstats stddev_samp(age); + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | stddev_samp(age) | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 3.304037933599835 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 3.304037933599835 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 3.304037933599835 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 3.304037933599835 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+ + + +#### STDDEV_POP + +**Description** + +`STDDEV_POP(expr)` Return the population standard deviation of expr. 
+
+Example:
+
+    os> source=accounts | eventstats stddev_pop(age);
+    fetched rows / total rows = 4/4
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+
+    | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | stddev_pop(age) |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+
+    | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 2.8613807855648994 |
+    | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 2.8613807855648994 |
+    | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 2.8613807855648994 |
+    | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 2.8613807855648994 |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+------------------------+
+
+
+#### PERCENTILE or PERCENTILE_APPROX
+
+**Description**
+
+`PERCENTILE(expr, percent)` or `PERCENTILE_APPROX(expr, percent)` Returns the approximate percentile value of expr at the specified percentage.
+
+* percent: The number must be a constant between 0 and 100.
+
+Example:
+
+    os> source=accounts | eventstats percentile(age, 90) by gender;
+    fetched rows / total rows = 4/4
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+
+    | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | percentile(age, 90) by gender |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+
+    | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 36 |
+    | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 36 |
+    | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28 |
+    | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 36 |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+--------------------------------+
+
+
+### Example 1: Calculate the average, sum and count of a field by group
+
+The example shows how to calculate the average age, the sum of ages, and the count of events for all the accounts, grouped by gender.
+
+PPL query:
+
+    os> source=accounts | eventstats avg(age) as avg_age, sum(age) as sum_age, count() as count by gender;
+    fetched rows / total rows = 4/4
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+
+    | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | avg_age | sum_age | count |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+
+    | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 33.666667 | 101 | 3 |
+    | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 33.666667 | 101 | 3 |
+    | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 28.000000 | 28 | 1 |
+    | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 33.666667 | 101 | 3 |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----------+-----------+-------+
+
+
+### Example 2: Calculate the count by a span
+
+The example gets the count of `age` values in each 10-year age interval.
+
+PPL query:
+
+    os> source=accounts | eventstats count(age) by span(age, 10) as age_span
+    fetched rows / total rows = 4/4
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+
+    | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | age_span |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+
+    | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 3 |
+    | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 3 |
+    | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 1 |
+    | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 3 |
+    +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+----------+
+
+
+### Example 3: Calculate the count by gender and span
+
+The example gets the count of events in each 5-year age interval, grouped by gender.
+ +PPL query: + + os> source=accounts | eventstats count() as cnt by span(age, 5) as age_span, gender + fetched rows / total rows = 4/4 + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + | account_number | balance | firstname | lastname | age | gender | address | employer | email | city | state | cnt | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + | 1 | 39225 | Amber | Duke | 32 | M | 880 Holmes Lane | Pyrami | amberduke@pyrami.com | Brogan | IL | 2 | + | 6 | 5686 | Hattie | Bond | 36 | M | 671 Bristol Street | Netagy | hattiebond@netagy.com | Dante | TN | 1 | + | 13 | 32838 | Nanette | Bates | 28 | F | 789 Madison Street | Quility | | Nogal | VA | 1 | + | 18 | 4180 | Dale | Adams | 33 | M | 467 Hutchinson Court | | daleadams@boink.com | Orick | MD | 2 | + +----------------+----------+-----------+----------+-----+--------+-----------------------+----------+--------------------------+--------+-------+-----+ + + +### Usage +- `source = table | eventstats avg(a) ` +- `source = table | where a < 50 | eventstats avg(c) ` +- `source = table | eventstats max(c) by b` +- `source = table | eventstats count(c) by b | head 5` +- `source = table | eventstats distinct_count(c)` +- `source = table | eventstats stddev_samp(c)` +- `source = table | eventstats stddev_pop(c)` +- `source = table | eventstats percentile(c, 90)` +- `source = table | eventstats percentile_approx(c, 99)` + +**Aggregations With Span** +- `source = table | eventstats count(a) by span(a, 10) as a_span` +- `source = table | eventstats sum(age) by span(age, 5) as age_span | head 2` +- `source = table | eventstats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` + +**Aggregations With TimeWindow Span (tumble windowing function)** + +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` +- `source = table | eventstats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` + +**Aggregations Group by Multiple Levels** + +- `source = table | eventstats avg(age) as avg_state_age by country, state | eventstats avg(avg_state_age) as avg_country_age by country` +- `source = table | eventstats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | eventstats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | eventstats avg(avg_state_age) as avg_adult_country_age by country` + diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala new file mode 100644 index 000000000..f1d287429 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala @@ -0,0 +1,379 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, 
WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLEventstatsITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test eventstats avg") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25, 70, 20, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25, 70, 20, 4)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + 
WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 50), + Row("Hello", 30, "New York", "USA", 2023, 4, 50)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 50, 70, 30, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 50, 70, 30, 2)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + 
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span and country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats avg, max, min, count by span and state") { + val frame = sql(s""" + | source = 
$testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, state + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 25, 1), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 20, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev by span with filter") { + val frame = sql(s""" + | source = $testTable | where country != 'USA' | eventstats stddev_samp(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 3.5355339059327378), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 3.5355339059327378)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev_pop by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats stddev_pop(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 2.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 2.5), + Row("Hello", 30, "New York", "USA", 2023, 4, 0.0)) + assertSameRows(expected, frame) + } + + test("test eventstats percentile by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats percentile_approx(age, 60) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 50.0)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = 
Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eventstats with eval") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, country | eval new_avg_age = avg_age - 10 | eventstats avg(new_avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 15.0, 12.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 10.0, 12.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 60.0, 40.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 20.0, 40.0)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats with eval and filter") { + val frame = sql(s""" + | source = $testTable| eventstats avg(age) as avg_age by country, state, name | eval avg_age_divide_20 = avg_age - 20 | eventstats avg(avg_age_divide_20) + | as avg_state_age by country, state | where avg_state_age > 0 | eventstats count(avg_state_age) as count_country_age_greater_20 by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 5.0, 5.0, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0, 50.0, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 10.0, 10.0, 2)) + assertSameRows(expected, frame) + } + + test("test eventstats distinct_count by span with filter") { + val exception = intercept[AnalysisException](sql(s""" + | source = $testTable | where state != 'California' | eventstats distinct_count(age) by span(age, 10) as age_span + | """.stripMargin)) + assert(exception.message.contains("Distinct window functions are not supported")) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 6138a94a2..fc92b0a14 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -18,6 +18,7 @@ WHERE: 'WHERE'; FIELDS: 'FIELDS'; RENAME: 'RENAME'; STATS: 'STATS'; +EVENTSTATS: 'EVENTSTATS'; DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index ae5f14498..2d9357ec5 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -125,7 +125,7 @@ renameCommand ; statsCommand - : STATS (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? + : (STATS | EVENTSTATS) (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? 
; dedupCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 5ac54127b..e1397a754 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -318,4 +318,8 @@ public T visitScalarSubquery(ScalarSubquery node, C context) { public T visitExistsSubquery(ExistsSubquery node, C context) { return visitChildren(node, context); } + + public T visitWindow(Window node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java new file mode 100644 index 000000000..26cd08831 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class Window extends UnresolvedPlan { + private final List windowFunctionList; + private final List partExprList; + private final List sortExprList; + @Setter private UnresolvedExpression span; + private UnresolvedPlan child; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitWindow(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 76a7a0c79..32ed2c92f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -82,6 +82,7 @@ import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; @@ -89,6 +90,7 @@ import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.WindowSpecTransformer; import scala.Option; import scala.Tuple2; import scala.collection.IterableLike; @@ -117,6 +119,7 @@ import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; +import static 
scala.collection.JavaConverters.seqAsJavaList; /** * Utility class to traverse PPL logical plan and translate it into catalyst logical plan @@ -328,6 +331,30 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); } + @Override + public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); + Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); + List partitionExpList = visitExpressionList(node.getPartExprList(), context); + UnresolvedExpression span = node.getSpan(); + if (!Objects.isNull(span)) { + visitExpression(span, context); + } + Seq partitionSpec = context.retainAllNamedParseExpressions(p -> p); + Seq orderSpec = seq(new ArrayList()); + Seq aggregatorFunctions = seq( + seqAsJavaList(windowFunctionExpressions).stream() + .map(w -> WindowSpecTransformer.buildAggregateWindowFunction(w, partitionSpec, orderSpec)) + .collect(Collectors.toList())); + return context.apply(p -> + new org.apache.spark.sql.catalyst.plans.logical.Window( + aggregatorFunctions, + partitionSpec, + orderSpec, + p)); + } + @Override public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { expressionAnalyzer.visitAlias(node, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 26a8e2278..ed7717188 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -270,14 +270,24 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext .map(this::internalVisitExpression) .orElse(null); - Aggregation aggregation = - new Aggregation( - aggListBuilder.build(), - emptyList(), - groupList, - span, - ArgumentFactory.getArgumentList(ctx)); - return aggregation; + if (ctx.STATS() != null) { + Aggregation aggregation = + new Aggregation( + aggListBuilder.build(), + emptyList(), + groupList, + span, + ArgumentFactory.getArgumentList(ctx)); + return aggregation; + } else { + Window window = + new Window( + aggListBuilder.build(), + groupList, + emptyList()); + window.setSpan(span); + return window; + } } /** Dedup command. 
*/ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java index 0e6ba2a1d..e6dd12032 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.expressions.Alias; import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Divide; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -16,6 +17,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; import org.apache.spark.sql.catalyst.expressions.TimeWindow; +import org.apache.spark.sql.catalyst.expressions.UnboundedFollowing$; import org.apache.spark.sql.catalyst.expressions.UnboundedPreceding$; import org.apache.spark.sql.catalyst.expressions.WindowExpression; import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; @@ -79,4 +81,21 @@ static NamedExpression buildRowNumber(Seq partitionSpec, Seq())); } + + static NamedExpression buildAggregateWindowFunction(Expression aggregator, Seq partitionSpec, Seq orderSpec) { + Alias aggregatorAlias = (Alias) aggregator; + WindowExpression aggWindowExpression = new WindowExpression( + aggregatorAlias.child(), + new WindowSpecDefinition( + partitionSpec, + orderSpec, + new SpecifiedWindowFrame(RowFrame$.MODULE$, UnboundedPreceding$.MODULE$, UnboundedFollowing$.MODULE$))); + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + aggWindowExpression, + aggregatorAlias.name(), + NamedExpression.newExprId(), + seq(new ArrayList()), + Option.empty(), + seq(new ArrayList())); + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala new file mode 100644 index 000000000..53bb65950 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala @@ -0,0 +1,256 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} + +class PPLLogicalPlanEventstatsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test eventstats avg") { + val context = new CatalystPlanContext + val logPlan 
= + planTransformer.visit(plan(pplParser, "source = table | eventstats avg(age)"), context) + + val table = UnresolvedRelation(Seq("table")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test eventstats avg, max, min, count") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count"), + context) + + val table = UnresolvedRelation(Seq("table")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = table | eventstats avg(age) by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, 
max(age) as max_age, min(age) as min_age, count(age) as count by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span"), + context) + + val table = UnresolvedRelation(Seq("table")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = 
Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple eventstats") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country"), + context) + + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("table")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}