diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 4c5b6c255..0f2193ce0 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -12,7 +12,7 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
 | Chen Dai | [dai-chen](https://github.com/dai-chen) | Amazon |
 | Vamsi Manohar | [vamsi-amazon](https://github.com/vamsi-amazon) | Amazon |
 | Peng Huo | [penghuo](https://github.com/penghuo) | Amazon |
-| Lior Perry | [yangdb](https://github.com/YANG-DB) | Amazon |
+| Lior Perry | [YANG-DB](https://github.com/YANG-DB) | Amazon |
 | Sean Kao | [seankao-az](https://github.com/seankao-az) | Amazon |
 | Anirudha Jadhav | [anirudha](https://github.com/anirudha) | Amazon |
 | Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon |
diff --git a/docs/PPL-Correlation-command.md b/docs/PPL-Correlation-command.md
new file mode 100644
index 000000000..f7ef3e266
--- /dev/null
+++ b/docs/PPL-Correlation-command.md
@@ -0,0 +1,283 @@
+## PPL Correlation Command
+
+## Overview
+
+Over the past year, the OpenSearch Observability and Security teams have been working on many aspects of improving data monitoring and visibility.
+The key idea behind this work is to enable users to dig into their data and surface the insights hidden within the massive corpus of logs, events, and observations.
+
+One fundamental concept that supports this process is the ability to correlate different data sources according to common dimensions and timeframes.
+This subject is well documented elsewhere, so this RFC will not argue for the necessity of correlation (the appendix references multiple related resources); instead, it focuses on structuring the linguistic support for such a capability.
+
+![](https://user-images.githubusercontent.com/48943349/253685892-225e78e1-0942-46b0-8f67-97f9412a1c4c.png)
+
+
+### Problem definition
+
+The appendix adds formal references to the problem domain in both Observability and Security, but the main takeaway is that this capability is fundamental to the daily work of such domain experts and SREs.
+They routinely face huge amounts of data arriving from different verticals (data sources) that share the same timeframes but are not synchronized in any formal manner.
+
+The ability to intersect these verticals by timeframe and by similar dimensions enriches the data and allows the desired insights to surface.
+
+**Example**
+Let's take the Observability domain, for which we have 3 distinct data sources:
+*- Logs*
+*- Metrics*
+*- Traces*
+
+Each data source may share many common dimensions, but to transition from one data source to another it is necessary to correlate them correctly.
+According to semantic naming conventions, logs, traces, and metrics share commonly named dimensions.
+
+Let's take the following examples:
+
+**Log**
+
+```
+{
+  "@timestamp": "2018-07-02T22:23:00.186Z",
+  "aws": {
+    "elb": {
+      "backend": {
+        "http": {
+          "response": {
+            "status_code": 500
+          }
+        },
+        "ip": "10.0.0.1",
+        "port": "80"
+      },
+      ...
+      "target_port": [
+        "10.0.0.1:80"
+      ],
+      "target_status_code": [
+        "500"
+      ],
+      "traceId": "Root=1-58337262-36d228ad5d99923122bbe354",
+      "type": "http"
+    }
+  },
+  "cloud": {
+    "provider": "aws"
+  },
+  "http": {
+    "request": {
+      ...
+    },
+    "communication": {
+      "source": {
+        "address": "192.168.131.39",
+        "ip": "192.168.131.39",
+        "port": 2817
+      }
+    }
+  },
+  "traceId": "Root=1-58337262-36d228ad5d99923122bbe354"
+}
+```
+
+This is an AWS ELB log arriving from a service residing on AWS.
+It shows that a `backend.http.response.status_code` was 500 - which is an error.
+
+This may come up as part of a monitoring process or an alert triggered by some rule. Once the error is identified, the next step is to collect as much data as possible surrounding this event so that the investigation can be done in the most intelligent and thorough way.
+
+The most obvious step would be to create a query that brings in all the data available for that timeframe - but in many cases this is too much of a brute-force action, as the sketch below shows.
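+
+For illustration, such a brute-force timeframe query could look like the following PPL sketch (the time bounds are hypothetical):
+
+`source alb_logs | where @timestamp >= '2018-07-02T22:00:00' AND @timestamp <= '2018-07-02T23:00:00'`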
+The data may be too large to analyze, and most of the time would be spent filtering out irrelevant data instead of actually locating the root cause of the problem.
+
+
+### **Suggested Correlation command**
+
+The approach suggested next allows searching in a much more fine-grained manner and further simplifies the analysis stage.
+
+Let's review the known facts - we have multiple dimensions that can be used to correlate data from other sources:
+
+- **IP** - `"ip": "10.0.0.1" | "ip": "192.168.131.39"`
+
+- **Port** - `"port": 2817 | "target_port": "10.0.0.1:80"`
+
+So, assuming we have the additional traces / metrics indices available, and using the fact that we know our schema structure (see the appendix for relevant schema references), we can generate a query that fetches all the relevant data bearing these dimensions during the same timeframe.
+
+Here is a snippet of a trace index document that has the HTTP information we would like to correlate with:
+
+```
+{
+  "traceId": "c1d985bd02e1dbb85b444011f19a1ecc",
+  "spanId": "55a698828fe06a42",
+  "traceState": [],
+  "parentSpanId": "",
+  "name": "mysql",
+  "kind": "CLIENT",
+  "@timestamp": "2021-11-13T20:20:39+00:00",
+  "events": [
+    {
+      "@timestamp": "2021-03-25T17:21:03+00:00",
+      ...
+    }
+  ],
+  "links": [
+    {
+      "traceId": "c1d985bd02e1dbb85b444011f19a1ecc",
+      "spanId": "55a698828fe06a42w2",
+      "droppedAttributesCount": 0
+    }
+  ],
+  "resource": {
+    "service@name": "database",
+    "telemetry@sdk@name": "opentelemetry",
+    "host@hostname": "ip-172-31-10-8.us-west-2.compute.internal"
+  },
+  "status": {
+    ...
+  },
+  "attributes": {
+    "http": {
+      "user_agent": {
+        "original": "Mozilla/5.0"
+      },
+      "network": {
+        ...
+      },
+      "request": {
+        ...
+      },
+      "response": {
+        "status_code": "200",
+        "body": {
+          "size": 500
+        }
+      },
+      "client": {
+        "server": {
+          "socket": {
+            "address": "192.168.0.1",
+            "domain": "example.com",
+            "port": 80
+          },
+          "address": "192.168.0.1",
+          "port": 80
+        },
+        "resend_count": 0,
+        "url": {
+          "full": "http://example.com"
+        }
+      },
+      "server": {
+        "route": "/index",
+        "address": "192.168.0.2",
+        "port": 8080,
+        "socket": {
+          ...
+        },
+        "client": {
+          ...
+        }
+      },
+      "url": {
+        ...
+      }
+    }
+  }
+}
+```
+
+In the above document we can see both the `traceId` and the HTTP client/server `ip`, which can be correlated with the ELB logs to better understand the system's behaviour and condition.
+
+
+### New Correlation Query Command
+
+Here is the new command that would allow this type of investigation:
+
+`source alb_logs, traces | where alb_logs.ip="10.0.0.1" AND alb_logs.cloud.provider="aws"| `
+`correlate exact fields(traceId, ip) scope(@timestamp, 1D) mapping(alb_logs.ip = traces.attributes.http.server.address, alb_logs.traceId = traces.traceId ) `
+
+Let's break this down a bit:
+
+`1. source alb_logs, traces` allows selecting all the data sources that will be correlated to one another
+
+`2. where ip="10.0.0.1" AND cloud.provider="aws"` the predicate clause constrains the scope of the search corpus
+
+`3. correlate exact fields(traceId, ip)` expresses the correlation operation on the following list of fields:
+
+`- ip` has an explicit filter condition, so it will be propagated into the correlation condition for all the data sources
+`- traceId` has no explicit filter, so the correlation will only match identical `traceId`s from all the data sources
+
+The field names indicate the logical meaning of each field within the correlation command; the actual join condition relies on the mapping statement described below.
+
+The term `exact` means that the correlation statements require all the fields to match in order to fulfill the query statement.
+
+The alternative, `approximate`, attempts to match on a best-effort basis and will not reject rows that only partially match. A sketch of the resulting join plans follows.
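+
+The following Catalyst sketch makes these semantics concrete (it mirrors the plan expectations in the integration tests of this change; the relation names are illustrative, and this is not the driver's literal output):
+
+```
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Or}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project}
+
+val albLogs = UnresolvedRelation(Seq("alb_logs"))
+val traces = UnresolvedRelation(Seq("traces"))
+
+// One equality predicate per mapped field pair
+val eqIp = EqualTo(
+  UnresolvedAttribute("alb_logs.ip"),
+  UnresolvedAttribute("traces.attributes.http.server.address"))
+val eqTraceId = EqualTo(
+  UnresolvedAttribute("alb_logs.traceId"),
+  UnresolvedAttribute("traces.traceId"))
+
+// `correlate exact`: all mapped fields must match -> inner join over AND-ed conditions
+val exactPlan = Project(
+  Seq(UnresolvedStar(None)),
+  Join(albLogs, traces, Inner, Some(And(eqIp, eqTraceId)), JoinHint.NONE))
+
+// `correlate approximate`: partial matches are kept -> full outer join over OR-ed conditions
+val approximatePlan = Project(
+  Seq(UnresolvedStar(None)),
+  Join(albLogs, traces, FullOuter, Some(Or(eqIp, eqTraceId)), JoinHint.NONE))
+```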
+### Addressing different field mapping
+
+In cases where the same logical field (such as `ip`) has a different mapping in each data source, the explicit mapping field path is expected.
+
+The following syntax extends the correlation conditions to allow matching different field names that carry a similar logical meaning:
+`alb_logs.ip = traces.attributes.http.server.address, alb_logs.traceId = traces.traceId`
+
+It is expected that for each `field` that participates in the correlation join, there is a relevant `mapping` statement that includes all the tables to be joined by this correlation command.
+
+**Example:**
+In our case there are 2 sources: `alb_logs, traces`
+There are 2 fields: `traceId, ip`
+There are 2 mapping statements: `alb_logs.ip = traces.attributes.http.server.address, alb_logs.traceId = traces.traceId`
+
+
+### Scoping the correlation timeframes
+
+In order to simplify the work that has to be done by the execution engine (driver), the scope statement was added to explicitly direct the join query to the time range it should scan for this search.
+
+`scope(@timestamp, 1D)` - in this example, the scope of the search is focused on a daily basis, so correlations appearing within the same day are grouped together. This scoping mechanism simplifies the query, allows better control over the results, and enables incremental search resolution based on the user's needs. A sketch of this bucketing idea follows.
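+
+As an illustration of the bucketing idea (not the driver's actual implementation - the `albLogs`/`traces` DataFrames and their column names are assumed), a daily scope can be thought of as an extra join key derived by truncating the timestamp:
+
+```
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, date_trunc}
+
+// scope(@timestamp, 1D): truncate both sides' timestamps to the day and add the
+// resulting bucket to the join keys alongside the mapped correlation fields.
+def correlateDaily(albLogs: DataFrame, traces: DataFrame): DataFrame = {
+  val left = albLogs.withColumn("scope_bucket", date_trunc("day", col("@timestamp")))
+  val right = traces.withColumn("scope_bucket", date_trunc("day", col("@timestamp")))
+  left.join(
+    right,
+    left("traceId") === right("traceId") &&
+      left("ip") === right("attributes.http.server.address") &&
+      left("scope_bucket") === right("scope_bucket"),
+    "inner")
+}
+```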
+
+***Diagram***
+These are the correlation conditions that explicitly state how the sources are going to be joined.
+
+[Image: Screenshot 2023-10-06 at 12.23.59 PM.png]
+
+* * *
+
+## Supporting Drivers
+
+The new correlation command is in effect a 'hidden' join command; therefore, only the following PPL drivers support it:
+
+- [ppl-spark](https://github.com/opensearch-project/opensearch-spark/tree/main/ppl-spark-integration)
+  In this driver the `correlation` command is translated directly into the appropriate Catalyst Join logical plan.
+
+**Example:**
+*`source alb_logs, traces, metrics | where ip="10.0.0.1" AND cloud.provider="aws"| correlate exact on (ip, port) scope(@timestamp, 2018-07-02T22:23:00, 1 D)`*
+
+**Logical Plan:**
+
+```
+'Project [*]
++- 'Join Inner, ('ip && 'port)
+   :- 'Filter (('ip === "10.0.0.1" && 'cloud.provider === "aws") && inTimeScope('@timestamp, "2018-07-02T22:23:00", "1 D"))
+   :  +- 'UnresolvedRelation [alb_logs]
+   +- 'Join Inner, ('ip && 'port)
+      :- 'Filter (('ip === "10.0.0.1" && 'cloud.provider === "aws") && inTimeScope('@timestamp, "2018-07-02T22:23:00", "1 D"))
+      :  +- 'UnresolvedRelation [traces]
+      +- 'Filter (('ip === "10.0.0.1" && 'cloud.provider === "aws") && inTimeScope('@timestamp, "2018-07-02T22:23:00", "1 D"))
+         +- 'UnresolvedRelation [metrics]
+```
+
+The Catalyst engine will optimize this query according to the most efficient join ordering.
+
+* * *
+
+## Appendix
+
+* Correlation concepts
+    * https://github.com/opensearch-project/sql/issues/1583
+    * https://github.com/opensearch-project/dashboards-observability/issues?q=is%3Aopen+is%3Aissue+label%3Ametrics
+* Observability Correlation
+    * https://opentelemetry.io/docs/specs/otel/trace/semantic_conventions/
+    * https://github.com/opensearch-project/dashboards-observability/wiki/Observability-Future-Vision#data-correlation
+* Security Correlation
+    * [OpenSearch new correlation engine](https://opensearch.org/docs/latest/security-analytics/usage/correlation-graph/)
+    * [ocsf](https://github.com/ocsf/)
+* Simple schema
+    * [correlation use cases](https://github.com/opensearch-project/dashboards-observability/wiki/Observability-Future-Vision#data-correlation)
+    * [correlation mapping metadata](https://github.com/opensearch-project/opensearch-catalog/tree/main/docs/schema)
+
+![](https://user-images.githubusercontent.com/48943349/274153824-9c6008e0-fdaf-434f-8e5d-4347cee66ac4.png)
+
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index e3fb467e6..40cb5c201 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -122,8 +122,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testIndex) index shouldBe defined - val optionJson = compact(render( - parse(index.get.metadata().getContent) \ "_meta" \ "options")) + val optionJson = compact(render(parse(index.get.metadata().getContent) \ "_meta" \ "options")) optionJson should matchJson(""" | { | "auto_refresh": "true", @@ -321,8 +320,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { |""".stripMargin) query.queryExecution.executedPlan should - useFlintSparkSkippingFileIndex( - hasIndexFilter(col("year") === 2023)) + useFlintSparkSkippingFileIndex(hasIndexFilter(col("year") === 2023)) } test("should not rewrite original query if filtering condition has disjunction") { @@ -388,8 +386,7 @@ class FlintSparkSkippingIndexITSuite
extends FlintSparkSuite { // Prepare test table val testTable = "spark_catalog.default.data_type_table" val testIndex = getSkippingIndexName(testTable) - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | boolean_col BOOLEAN, @@ -408,8 +405,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | ) | USING PARQUET |""".stripMargin) - sql( - s""" + sql(s""" | INSERT INTO $testTable | VALUES ( | TRUE, @@ -449,8 +445,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val index = flint.describeIndex(testIndex) index shouldBe defined - index.get.metadata().getContent should matchJson( - s"""{ + index.get.metadata().getContent should matchJson(s"""{ | "_meta": { | "name": "flint_spark_catalog_default_data_type_table_skipping_index", | "version": "${current()}", @@ -587,8 +582,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { test("can build skipping index for varchar and char and rewrite applicable query") { val testTable = "spark_catalog.default.varchar_char_table" val testIndex = getSkippingIndexName(testTable) - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | varchar_col VARCHAR(20), @@ -596,8 +590,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | ) | USING PARQUET |""".stripMargin) - sql( - s""" + sql(s""" | INSERT INTO $testTable | VALUES ( | "sample varchar", @@ -613,8 +606,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { .create() flint.refreshIndex(testIndex, FULL) - val query = sql( - s""" + val query = sql(s""" | SELECT varchar_col, char_col | FROM $testTable | WHERE varchar_col = "sample varchar" AND char_col = "sample char" @@ -624,8 +616,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val paddedChar = "sample char".padTo(20, ' ') checkAnswer(query, Row("sample varchar", paddedChar)) query.queryExecution.executedPlan should - useFlintSparkSkippingFileIndex(hasIndexFilter( - col("varchar_col") === "sample varchar" && col("char_col") === paddedChar)) + useFlintSparkSkippingFileIndex( + hasIndexFilter(col("varchar_col") === "sample varchar" && col("char_col") === paddedChar)) flint.deleteIndex(testIndex) } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala similarity index 99% rename from integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala rename to integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index 9dea04872..8f1d1bd1f 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -11,7 +11,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest -class FlintSparkPPLITSuite +class FlintSparkPPLBasicITSuite extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala new file mode 100644 index 000000000..61564546e --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -0,0 +1,751 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, Literal, Multiply, Or, SortOrder} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.JoinHint.NONE +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCorrelationITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable1 = "spark_catalog.default.flint_ppl_test1" + private val testTable2 = "spark_catalog.default.flint_ppl_test2" + private val testTable3 = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test tables + sql(s""" + | CREATE TABLE $testTable1 + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + sql(s""" + | CREATE TABLE $testTable2 + | ( + | name STRING, + | occupation STRING, + | country STRING, + | salary INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable1 + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jim', 27, 'B.C', 'Canada'), + | ('Peter', 57, 'B.C', 'Canada'), + | ('Rick', 70, 'B.C', 'Canada'), + | ('David', 40, 'Washington', 'USA'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable2 + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'Engineer', 'England' , 100000), + | ('Hello', 'Artist', 'USA', 70000), + | ('John', 'Doctor', 'Canada', 120000), + | ('David', 'Doctor', 'USA', 120000), + | ('David', 'Unemployed', 'Canada', 0), + | ('Jane', 'Scientist', 'Canada', 90000) + | """.stripMargin) + sql(s""" + | CREATE TABLE $testTable3 + | ( + | name STRING, + | country STRING, + | hobby STRING, + | language STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable3 + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'USA', 'Fishing', 'English'), + | ('Hello', 'USA', 'Painting', 'English'), + | ('John', 'Canada', 'Reading', 'French'), + | ('Jim', 'Canada', 'Hiking', 'English'), + | ('Peter', 'Canada', 'Gaming', 'English'), + | ('Rick', 'USA', 'Swimming', 'English'), + | ('David', 'USA', 'Gardening', 'English'), + | ('Jane', 'Canada', 'Singing', 'French') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create failing ppl correlation query - due to mismatch fields 
to mappings test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate exact fields(name, country) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + } + assert( + thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ") + } + test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + } + assert( + thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ") + } + + test( + "create failing ppl correlation query - due to mismatch correlation self type and source amount test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate self fields(name, country) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + } + assert( + thrown.getMessage === "Correlation command with `inner` type must have exactly on source table ") + } + + test( + "create failing ppl correlation query - due to mismatch correlation exact type and source amount test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable1 | correlate approximate fields(name) scope(month, 1W) mapping($testTable1.name = $testTable1.inner_name) + | """.stripMargin) + } + assert( + thrown.getMessage === "Correlation command with `approximate` type must at least two different source tables ") + } + + test( + "create ppl correlation exact query with filters and two tables correlating on a single field test") { + val joinQuery = + s""" + | SELECT a.name, a.age, a.state, a.country, b.occupation, b.salary + | FROM $testTable1 AS a + | JOIN $testTable2 AS b + | ON a.name = b.name + | WHERE a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + |""".stripMargin + + val result = spark.sql(joinQuery) + result.show() + + val frame = sql(s""" + | source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row( + "Jake", + 70, + "California", + "USA", + 2023, + 4, + "Jake", + "Engineer", + "England", + 100000, + 2023, + 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + val filter2Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation approximate query with filters and two tables correlating on a single field test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate approximate fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row( + "Jake", + 70, + "California", + "USA", + 2023, + 4, + "Jake", + "Engineer", + "England", + 100000, + 2023, + 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4), + Row("Jim", 27, "B.C", "Canada", 2023, 4, null, null, null, null, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("Peter", 57, "B.C", "Canada", 2023, 4, null, null, null, null, null, null), + Row("Rick", 70, "B.C", "Canada", 2023, 4, null, null, null, null, null, null)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(table1, table2, FullOuter, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test( + "create ppl correlation approximate query with two tables correlating on a single field and not scope test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate approximate fields(name) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + // Retrieve the results + val results: 
Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row( + "Jake", + 70, + "California", + "USA", + 2023, + 4, + "Jake", + "Engineer", + "England", + 100000, + 2023, + 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4), + Row("Jim", 27, "B.C", "Canada", 2023, 4, null, null, null, null, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("Peter", 57, "B.C", "Canada", 2023, 4, null, null, null, null, null, null), + Row("Rick", 70, "B.C", "Canada", 2023, 4, null, null, null, null, null, null)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(table1, table2, FullOuter, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation query with with filters and two tables correlating on a two fields test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name, country) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name, $testTable1.country = $testTable2.country) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + val filter2Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, 
table2) + + // Define join condition + val joinCondition = + And( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name")), + EqualTo( + UnresolvedAttribute(s"$testTable1.country"), + UnresolvedAttribute(s"$testTable2.country"))) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation query with two tables correlating on a two fields and disjoint filters test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| where year = 2023 AND month = 4 AND $testTable2.salary > 100000 | correlate exact fields(name, country) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name, $testTable1.country = $testTable2.country) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + val filter2Expr = And( + And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))), + GreaterThan(UnresolvedAttribute(s"$testTable2.salary"), Literal(100000))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + And( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name")), + EqualTo( + UnresolvedAttribute(s"$testTable1.country"), + UnresolvedAttribute(s"$testTable2.country"))) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation (exact) query with two tables correlating by name and group by avg salary by age span (10 years bucket) test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate exact fields(name) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name) | + | stats avg(salary) by span(age, 10) as age_span + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(100000.0, 70), Row(105000.0, 20), Row(60000.0, 40), Row(70000.0, 30)) + + 
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(table1, table2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(span), Seq(aggregateExpressions, span), joinPlan) + // Add the projection + val expectedPlan = Project(star, aggregatePlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation (exact) query with two tables correlating by name and group by avg salary by age span (10 years bucket) and country test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate exact fields(name) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name) | + | stats avg(salary) by span(age, 10) as age_span, $testTable2.country + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(120000.0, "USA", 40), + Row(0.0, "Canada", 40), + Row(70000.0, "USA", 30), + Row(100000.0, "England", 70), + Row(105000.0, "Canada", 20)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(table1, table2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute(s"$testTable2.country") + val countryAlias = Alias(countryField, s"$testTable2.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + // Add the projection + val expectedPlan = Project(star, aggregatePlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation 
(exact) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) with country filter test") { + val frame = sql(s""" + | source = $testTable1, $testTable2 | where country = 'USA' OR country = 'England' | + | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) | + | stats avg(salary) by span(age, 10) as age_span, $testTable2.country + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(120000.0, "USA", 40), Row(100000.0, "England", 70), Row(70000.0, "USA", 30)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val filter2Expr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute(s"$testTable2.country") + val countryAlias = Alias(countryField, s"$testTable2.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + // Add the projection + val expectedPlan = Project(star, aggregatePlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test( + "create ppl correlation (exact) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) with country filter without scope test") { + val frame = sql(s""" + | source = $testTable1, $testTable2 | where country = 'USA' OR country = 'England' | + | correlate exact fields(name) mapping($testTable1.name = $testTable2.name) | + | stats avg(salary) by span(age, 10) as age_span, $testTable2.country + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(120000.0, "USA", 40), Row(100000.0, "England", 70), Row(70000.0, "USA", 30)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val filter2Expr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute(s"$testTable2.country") + val countryAlias = Alias(countryField, s"$testTable2.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + // Add the projection + val expectedPlan = Project(star, aggregatePlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl correlation (approximate) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| correlate approximate fields(name, country) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name, $testTable1.country = $testTable2.country) | + | stats avg(salary) by span(age, 10) as age_span, $testTable2.country | sort - age_span | head 5 + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(70000.0, "Canada", 70L), + Row(100000.0, "England", 70L), + Row(95000.0, "USA", 70L), + Row(70000.0, "Canada", 50L), + Row(95000.0, "USA", 40L)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](2)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define join condition - according to the correlation (approximate) type + val joinCondition = + Or( + EqualTo( + UnresolvedAttribute(s"$testTable1.name"), + UnresolvedAttribute(s"$testTable2.name")), + EqualTo( + UnresolvedAttribute(s"$testTable1.country"), + UnresolvedAttribute(s"$testTable2.country"))) + + // Create Join plan + val joinPlan = Join(table1, table2, FullOuter, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute(s"$testTable2.country") + val countryAlias = Alias(countryField, s"$testTable2.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + 
Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + // sort by age_span + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), + global = true, + aggregatePlan) + + val limitPlan = Limit(Literal(5), sortedPlan) + val expectedPlan = Project(star, limitPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index fb46ce4de..62ff50fb6 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -158,7 +158,6 @@ class FlintSparkPPLFiltersITSuite // Define the expected results val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) // Compare the results - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -306,7 +305,6 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - test("create ppl simple name literal not equal filter query with two fields result test") { val frame = sql(s""" | source = $testTable name!='Jake' | fields name, age @@ -333,70 +331,6 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - - test("create ppl simple avg age by span of interval of 10 years query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test( - "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { - val frame = sql(s""" - | source = $testTable| 
stats avg(age) by span(age, 10) as age_span | head 2 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 2) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val limitPlan = Limit(Literal(2), aggregatePlan) - val expectedPlan = Project(star, limitPlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - /** * | age_span | country | average_age | * |:---------|:--------|:------------| diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index e74aed30e..b1c988b28 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -36,6 +36,14 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; +//CORRELATION KEYWORDS +CORRELATE: 'CORRELATE'; +SELF: 'SELF'; +EXACT: 'EXACT'; +APPROXIMATE: 'APPROXIMATE'; +SCOPE: 'SCOPE'; +MAPPING: 'MAPPING'; + // COMMAND ASSIST KEYWORDS AS: 'AS'; BY: 'BY'; @@ -262,6 +270,7 @@ DAYOFWEEK: 'DAYOFWEEK'; DAYOFYEAR: 'DAYOFYEAR'; DAY_OF_MONTH: 'DAY_OF_MONTH'; DAY_OF_WEEK: 'DAY_OF_WEEK'; +DURATION: 'DURATION'; EXTRACT: 'EXTRACT'; FROM_DAYS: 'FROM_DAYS'; FROM_UNIXTIME: 'FROM_UNIXTIME'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 69f560f25..4b4e64c1a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -33,6 +33,7 @@ pplCommands commands : whereCommand + | correlateCommand | fieldsCommand | renameCommand | statsCommand @@ -61,12 +62,34 @@ describeCommand ; showDataSourcesCommand - : SHOW DATASOURCES - ; + : SHOW DATASOURCES + ; whereCommand - : WHERE logicalExpression - ; + : WHERE logicalExpression + ; + +correlateCommand + : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS (scopeClause)? mappingList + ; + +correlationType + : SELF + | EXACT + | APPROXIMATE + ; + +scopeClause + : SCOPE LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS + ; + +mappingList + : MAPPING LT_PRTHS ( mappingClause (COMMA mappingClause)* ) RT_PRTHS + ; + +mappingClause + : left = qualifiedName comparisonOperator right = qualifiedName # mappingCompareExpr + ; fieldsCommand : FIELDS (PLUS | MINUS)? 
fieldList @@ -820,6 +843,7 @@ keywordsCanBeId | SHOW | FROM | WHERE + | CORRELATE | FIELDS | RENAME | STATS diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 9a2e88484..e3d0c6a2b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -16,6 +16,7 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; @@ -35,6 +36,7 @@ import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Correlation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; @@ -94,6 +96,14 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitCorrelation(Correlation node, C context) { + return visitChildren(node, context); + } + + public T visitCorrelationMapping(FieldsMapping node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java index f19de2a05..f783aabb7 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java @@ -11,28 +11,12 @@ import java.util.List; /** Expression node of logic AND. 
*/ -public class And extends UnresolvedExpression { - private UnresolvedExpression left; - private UnresolvedExpression right; +public class And extends BinaryExpression { public And(UnresolvedExpression left, UnresolvedExpression right) { - this.left = left; - this.right = right; + super(left,right); } - - @Override - public List getChild() { - return Arrays.asList(left, right); - } - - public UnresolvedExpression getLeft() { - return left; - } - - public UnresolvedExpression getRight() { - return right; - } - + @Override public R accept(AbstractNodeVisitor nodeVisitor, C context) { return nodeVisitor.visitAnd(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/BinaryExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/BinaryExpression.java new file mode 100644 index 000000000..a50a153a0 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/BinaryExpression.java @@ -0,0 +1,29 @@ +package org.opensearch.sql.ast.expression; + +import java.util.Arrays; +import java.util.List; + +public abstract class BinaryExpression extends UnresolvedExpression { + private UnresolvedExpression left; + private UnresolvedExpression right; + + public BinaryExpression(UnresolvedExpression left, UnresolvedExpression right) { + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + public UnresolvedExpression getLeft() { + return left; + } + + public UnresolvedExpression getRight() { + return right; + } + +} + diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java index 7c77fae1f..39b42dfe4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import org.opensearch.sql.ast.AbstractNodeVisitor; +import java.util.ArrayList; import java.util.Collections; import java.util.List; public class Field extends UnresolvedExpression { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java new file mode 100644 index 000000000..37d31b822 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java @@ -0,0 +1,22 @@ +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +public class FieldsMapping extends UnresolvedExpression { + + private final List fieldsMappingList; + + public FieldsMapping(List fieldsMappingList) { + this.fieldsMappingList = fieldsMappingList; + } + public List getChild() { + return fieldsMappingList; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCorrelationMapping(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java index 65e1a2e6d..d76cda695 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java @@ -12,28 +12,10 @@ /** Expression node of the 
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java
index 65e1a2e6d..d76cda695 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java
@@ -12,28 +12,10 @@
 
 /** Expression node of the logic OR. */
-public class Or extends UnresolvedExpression {
-  private UnresolvedExpression left;
-  private UnresolvedExpression right;
-
+public class Or extends BinaryExpression {
   public Or(UnresolvedExpression left, UnresolvedExpression right) {
-    this.left = left;
-    this.right = right;
+    super(left, right);
   }
-
-  @Override
-  public List<UnresolvedExpression> getChild() {
-    return Arrays.asList(left, right);
-  }
-
-  public UnresolvedExpression getLeft() {
-    return left;
-  }
-
-  public UnresolvedExpression getRight() {
-    return right;
-  }
-
   @Override
   public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
     return nodeVisitor.visitOr(this, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java
new file mode 100644
index 000000000..3fbe53cd2
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java
@@ -0,0 +1,8 @@
+package org.opensearch.sql.ast.expression;
+
+/** Scope expression node. Params include the field expression and the scope value. */
+public class Scope extends Span {
+  public Scope(UnresolvedExpression field, UnresolvedExpression value, SpanUnit unit) {
+    super(field, value, unit);
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java
index 9368a6363..9f618a067 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java
@@ -12,28 +12,14 @@
 
 /** Expression node of the logic XOR. */
-public class Xor extends UnresolvedExpression {
-  private UnresolvedExpression left;
-  private UnresolvedExpression right;
-
+public class Xor extends BinaryExpression {
   public Xor(UnresolvedExpression left, UnresolvedExpression right) {
-    this.left = left;
-    this.right = right;
+    super(left, right);
   }
-
-  @Override
-  public List<UnresolvedExpression> getChild() {
-    return Arrays.asList(left, right);
-  }
-
-  public UnresolvedExpression getLeft() {
-    return left;
-  }
-
-  public UnresolvedExpression getRight() {
-    return right;
-  }
-
+
   @Override
   public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
     return nodeVisitor.visitXor(this, context);
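With the shared base class in place, adding another binary node is only a constructor plus an accept override; everything else is inherited. A hypothetical sketch (a `Nand` node is not part of this change, and the `visitChildren` fallback is assumed to exist on `AbstractNodeVisitor`):

```java
package org.opensearch.sql.ast.expression;

import org.opensearch.sql.ast.AbstractNodeVisitor;

// Hypothetical example only: getChild()/getLeft()/getRight() come from BinaryExpression.
public class Nand extends BinaryExpression {
  public Nand(UnresolvedExpression left, UnresolvedExpression right) {
    super(left, right);
  }

  @Override
  public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
    // A real node would get its own visitNand hook; fall back to visitChildren here.
    return nodeVisitor.visitChildren(this, context);
  }
}
```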
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java
new file mode 100644
index 000000000..6cc2b66ff
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java
@@ -0,0 +1,67 @@
+package org.opensearch.sql.ast.tree;
+
+import com.google.common.collect.ImmutableList;
+import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.Node;
+import org.opensearch.sql.ast.expression.FieldsMapping;
+import org.opensearch.sql.ast.expression.Scope;
+import org.opensearch.sql.ast.expression.SpanUnit;
+import org.opensearch.sql.ast.expression.UnresolvedExpression;
+
+import java.util.List;
+
+/** Logical plan node of correlation; the interface for correlating the searched sources. */
+public class Correlation extends UnresolvedPlan {
+  private final CorrelationType correlationType;
+  private final List<UnresolvedExpression> fieldsList;
+  private final Scope scope;
+  private final FieldsMapping mappingListContext;
+  private UnresolvedPlan child;
+
+  public Correlation(String correlationType, List<UnresolvedExpression> fieldsList, Scope scope, FieldsMapping mappingListContext) {
+    this.correlationType = CorrelationType.valueOf(correlationType);
+    this.fieldsList = fieldsList;
+    this.scope = scope;
+    this.mappingListContext = mappingListContext;
+  }
+
+  @Override
+  public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
+    return nodeVisitor.visitCorrelation(this, context);
+  }
+
+  @Override
+  public List<? extends Node> getChild() {
+    return ImmutableList.of(child);
+  }
+
+  @Override
+  public Correlation attach(UnresolvedPlan child) {
+    this.child = child;
+    return this;
+  }
+
+  public CorrelationType getCorrelationType() {
+    return correlationType;
+  }
+
+  public List<UnresolvedExpression> getFieldsList() {
+    return fieldsList;
+  }
+
+  public Scope getScope() {
+    return scope;
+  }
+
+  public FieldsMapping getMappingListContext() {
+    return mappingListContext;
+  }
+
+  public enum CorrelationType {
+    self,
+    exact,
+    approximate
+  }
+}
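To ground this node in the command it models: for `correlate exact fields(traceId, ip) scope(@timestamp, 1D) mapping(...)`, the parser (see `AstBuilder.visitCorrelateCommand` further down) produces roughly the following. This is a hand-written sketch against the AST API; the `Literal`/`DataType` scope value mirrors how the builder constructs `Scope`, and is an assumption here:

```java
import java.util.Arrays;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Scope;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.Correlation;

public class CorrelationAstSketch {
  public static void main(String[] args) {
    // correlate exact fields(traceId, ip) scope(@timestamp, 1D) mapping(...)
    Correlation correlation = new Correlation(
        "exact",
        Arrays.<UnresolvedExpression>asList(QualifiedName.of("traceId"), QualifiedName.of("ip")),
        new Scope(QualifiedName.of("@timestamp"), new Literal(1, DataType.INTEGER), SpanUnit.of("d")),
        new FieldsMapping(Arrays.asList(/* Compare expressions from the mapping clause */)));
  }
}
```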
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
index 7e21ac9a9..66ed765a3 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
@@ -5,18 +5,24 @@
 package org.opensearch.sql.ppl;
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
+import scala.collection.Iterator;
 import scala.collection.Seq;
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Optional;
 import java.util.Stack;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+import static scala.collection.JavaConverters.asJavaCollection;
 import static scala.collection.JavaConverters.asScalaBuffer;
 
 /**
@@ -27,6 +33,10 @@ public class CatalystPlanContext {
      * Catalyst evolving logical plan
      **/
     private Stack<LogicalPlan> planBranches = new Stack<>();
+    /**
+     * The current traversal context the visitor is going through
+     */
+    private Stack<LogicalPlan> planTraversalContext = new Stack<>();
 
     /**
      * NamedExpression contextual parameters
@@ -37,7 +47,11 @@
      * Grouping NamedExpression contextual parameters
      **/
     private final Stack<Expression> groupingParseExpressions = new Stack<>();
-
+
+    public Stack<LogicalPlan> getPlanBranches() {
+        return planBranches;
+    }
+
     public LogicalPlan getPlan() {
         if (this.planBranches.size() == 1) {
             return planBranches.peek();
@@ -46,47 +60,151 @@ public LogicalPlan getPlan() {
         return new Union(asScalaBuffer(this.planBranches), true, true);
     }
 
+    /**
+     * get the current traversal visitor context
+     *
+     * @return the stack of plans currently being traversed
+     */
+    public Stack<LogicalPlan> traversalContext() {
+        return planTraversalContext;
+    }
+
     public Stack<Expression> getNamedParseExpressions() {
         return namedParseExpressions;
     }
 
+    public Optional<Expression> popNamedParseExpressions() {
+        return namedParseExpressions.isEmpty() ? Optional.empty() : Optional.of(namedParseExpressions.pop());
+    }
+
     public Stack<Expression> getGroupingParseExpressions() {
         return groupingParseExpressions;
     }
 
     /**
-     * append context with evolving plan
+     * append plan to the evolving plan branches
      *
      * @param plan
+     * @return the pushed plan
      */
-    public void with(LogicalPlan plan) {
-        this.planBranches.push(plan);
+    public LogicalPlan with(LogicalPlan plan) {
+        return this.planBranches.push(plan);
     }
 
-    public LogicalPlan plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
-        this.planBranches.replaceAll(transformFunction::apply);
+    /**
+     * append a collection of plans to the evolving plan branches
+     *
+     * @param plans
+     * @return the resulting plan
+     */
+    public LogicalPlan withAll(Collection<LogicalPlan> plans) {
+        this.planBranches.addAll(plans);
         return getPlan();
     }
-
-    /**
+
+    /**
+     * reduce all plan branches with the given reduce function
+     *
+     * @param transformFunction
+     * @return the reduced plan
+     */
+    public LogicalPlan reduce(BiFunction<LogicalPlan, LogicalPlan, LogicalPlan> transformFunction) {
+        Collection<LogicalPlan> logicalPlans = asJavaCollection(retainAllPlans(p -> p));
+        // in case of a self join - a single table - apply the same plan to both sides
+        if (logicalPlans.size() < 2) {
+            return with(logicalPlans.stream().map(plan -> {
+                planTraversalContext.push(plan);
+                LogicalPlan result = transformFunction.apply(plan, plan);
+                planTraversalContext.pop();
+                return result;
+            }).findAny()
+                    .orElse(getPlan()));
+        }
+        // in case there are multiple join tables - reduce the tables pairwise
+        return with(logicalPlans.stream().reduce((left, right) -> {
+            planTraversalContext.push(left);
+            planTraversalContext.push(right);
+            LogicalPlan result = transformFunction.apply(left, right);
+            planTraversalContext.pop();
+            planTraversalContext.pop();
+            return result;
+        }).orElse(getPlan()));
+    }
+
+    /**
+     * apply the given transform function to each plan branch
+     *
+     * @param transformFunction
+     * @return the resulting plan
+     */
+    public LogicalPlan apply(Function<LogicalPlan, LogicalPlan> transformFunction) {
+        return withAll(asJavaCollection(retainAllPlans(p -> p)).stream().map(p -> {
+            planTraversalContext.push(p);
+            LogicalPlan result = transformFunction.apply(p);
+            planTraversalContext.pop();
+            return result;
+        }).collect(Collectors.toList()));
+    }
+
+    /**
+     * retain all logical plan branches and clear the plan stack
+     *
+     * @return the retained plans
+     */
+    public <T> Seq<T> retainAllPlans(Function<LogicalPlan, T> transformFunction) {
+        Seq<T> plans = seq(getPlanBranches().stream().map(transformFunction).collect(Collectors.toList()));
+        getPlanBranches().retainAll(emptyList());
+        return plans;
+    }
+
+    /**
      * retain all expressions and clear expression stack
+     *
      * @return the retained expressions
      */
     public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
         Seq<T> aggregateExpressions = seq(getNamedParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList()));
+                .map(transformFunction).collect(Collectors.toList()));
         getNamedParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
 
     /**
      * retain all aggregate expressions and clear expression stack
+     *
      * @return the retained expressions
      */
     public <T> Seq<T> retainAllGroupingNamedParseExpressions(Function<Expression, T> transformFunction) {
         Seq<T> aggregateExpressions = seq(getGroupingParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList()));
+                .map(transformFunction).collect(Collectors.toList()));
         getGroupingParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
+
+    public static List<UnresolvedRelation> findRelation(Stack<LogicalPlan> plan) {
+        return plan.stream()
+                .map(CatalystPlanContext::findRelation)
+                .filter(Optional::isPresent)
+                .map(Optional::get)
+                .collect(Collectors.toList());
+    }
+
+    public static Optional<UnresolvedRelation> findRelation(LogicalPlan plan) {
+        // Check if the current node is an UnresolvedRelation
+        if (plan instanceof UnresolvedRelation) {
+            return Optional.of((UnresolvedRelation) plan);
+        }
+
+        // Traverse the children of the current node
+        Iterator<LogicalPlan> children = plan.children().iterator();
+        while (children.hasNext()) {
+            Optional<UnresolvedRelation> result = findRelation(children.next());
+            if (result.isPresent()) {
+                return result;
+            }
+        }
+
+        // Return empty if no UnresolvedRelation is found
+        return Optional.empty();
+    }
 }
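To see the new branch-reduction machinery in isolation: pushing two relations and reducing them should yield a single join plan. A hand-assembled sketch (the join condition is omitted here; in the real flow it comes from the correlation mapping, as the visitor below shows):

```java
import static java.util.List.of;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.plans.Inner$;
import org.apache.spark.sql.catalyst.plans.logical.Join;
import org.apache.spark.sql.catalyst.plans.logical.JoinHint;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ppl.CatalystPlanContext;
import scala.Option;

public class ReduceSketch {
  public static void main(String[] args) {
    CatalystPlanContext context = new CatalystPlanContext();
    context.with(new UnresolvedRelation(seq(of("alb_logs")), CaseInsensitiveStringMap.empty(), false));
    context.with(new UnresolvedRelation(seq(of("traces")), CaseInsensitiveStringMap.empty(), false));
    // Collapse the two branches into one inner join (condition omitted in this sketch).
    LogicalPlan joined = context.reduce((left, right) ->
        new Join(left, right, Inner$.MODULE$, Option.empty(), JoinHint.NONE()));
  }
}
```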
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index ff7e54e22..6d14db328 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -23,15 +23,18 @@
 import org.opensearch.sql.ast.expression.AllFields;
 import org.opensearch.sql.ast.expression.And;
 import org.opensearch.sql.ast.expression.Argument;
+import org.opensearch.sql.ast.expression.BinaryExpression;
 import org.opensearch.sql.ast.expression.Case;
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.Field;
+import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Function;
 import org.opensearch.sql.ast.expression.In;
 import org.opensearch.sql.ast.expression.Interval;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
+import org.opensearch.sql.ast.expression.QualifiedName;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 import org.opensearch.sql.ast.expression.WindowFunction;
@@ -40,6 +43,7 @@
 import org.opensearch.sql.ast.statement.Query;
 import org.opensearch.sql.ast.statement.Statement;
 import org.opensearch.sql.ast.tree.Aggregation;
+import org.opensearch.sql.ast.tree.Correlation;
 import org.opensearch.sql.ast.tree.Dedupe;
 import org.opensearch.sql.ast.tree.Eval;
 import org.opensearch.sql.ast.tree.Filter;
@@ -57,12 +61,17 @@
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
+import java.util.function.BiFunction;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
 import static java.util.List.of;
+import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate;
+import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join;
+import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField;
 import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;
 
 /**
@@ -95,19 +104,39 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
-        node.getTableName().forEach(t -> {
-            // Resolving the qualifiedName which is composed of a datasource.schema.table
-            context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false));
-        });
+        node.getTableName().forEach(t ->
+                // Resolving the qualifiedName which is composed of a datasource.schema.table
+                context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false))
+        );
         return context.getPlan();
     }
 
     @Override
     public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
         node.getChild().get(0).accept(this, context);
-        Expression conditionExpression = visitExpression(node.getCondition(), context);
-        Expression innerConditionExpression = context.getNamedParseExpressions().pop();
-        return context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression, p));
+        return context.apply(p -> {
+            Expression conditionExpression = visitExpression(node.getCondition(), context);
+            Optional<Expression> innerConditionExpression = context.popNamedParseExpressions();
+            return innerConditionExpression.map(expression -> new org.apache.spark.sql.catalyst.plans.logical.Filter(expression, p)).orElse(null);
+        });
+    }
+
+    @Override
+    public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
+        node.getChild().get(0).accept(this, context);
+        context.reduce((left, right) -> {
+            visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
+            Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);
+            if (!Objects.isNull(node.getScope())) {
+                // scope - a time-based expression that frames the join to a specific period: (time-field-name, value, unit)
+                expressionAnalyzer.visitSpan(node.getScope(), context);
+                context.popNamedParseExpressions().get();
+            }
+            expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context);
+            Seq<Expression> mapping = context.retainAllNamedParseExpressions(e -> e);
+            return join(node.getCorrelationType(), fields, mapping, left, right);
+        });
+        return context.getPlan();
     }
 
     @Override
@@ -130,11 +159,11 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
         // build the aggregation logical step
         return extractedAggregation(context);
     }
-
+
     private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
         Seq<Expression> groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p);
         Seq<NamedExpression> aggregateExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
-        return context.plan(p -> new Aggregate(groupingExpression, aggregateExpressions, p));
+        return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p));
     }
 
     @Override
@@ -153,7 +182,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
         if (!projectList.isEmpty()) {
             Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
             // build the plan with the projection step
-            child = context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+            child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
         }
         if (node.hasArgument()) {
             Argument argument = node.getArgExprList().get(0);
@@ -161,19 +190,19 @@
         }
         return child;
     }
-
+
     @Override
     public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
         node.getChild().get(0).accept(this, context);
         visitFieldList(node.getSortList(), context);
         Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp));
-        return context.plan(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
+        return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
     }
 
     @Override
     public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
         node.getChild().get(0).accept(this, context);
-        return context.plan(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
+        return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
                 node.getSize(), DataTypes.IntegerType), p));
     }
 
@@ -242,53 +271,67 @@ public Expression visitLiteral(Literal node, CatalystPlanContext context) {
                 translate(node.getValue(), node.getType()), translate(node.getType())));
     }
 
-    @Override
-    public Expression visitAnd(And node, CatalystPlanContext context) {
+    /**
+     * generic binary (And, Or, Xor, ...) arithmetic expression resolver
+     *
+     * @param node the binary expression node
+     * @param transformer the Catalyst expression factory for this operator
+     * @param context the current plan context
+     * @return the pushed Catalyst expression, or null when neither side resolved
+     */
+    public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction<Expression, Expression, Expression> transformer, CatalystPlanContext context) {
         node.getLeft().accept(this, context);
-        Expression left = (Expression) context.getNamedParseExpressions().pop();
+        Optional<Expression> left = context.popNamedParseExpressions();
         node.getRight().accept(this, context);
-        Expression right = (Expression) context.getNamedParseExpressions().pop();
-        return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right));
+        Optional<Expression> right = context.popNamedParseExpressions();
+        if (left.isPresent() && right.isPresent()) {
+            return transformer.apply(left.get(), right.get());
+        } else if (left.isPresent()) {
+            return context.getNamedParseExpressions().push(left.get());
+        } else if (right.isPresent()) {
+            return context.getNamedParseExpressions().push(right.get());
+        }
+        return null;
+    }
+
+    @Override
+    public Expression visitAnd(And node, CatalystPlanContext context) {
+        return visitBinaryArithmetic(node,
+                (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context);
     }
 
     @Override
     public Expression visitOr(Or node, CatalystPlanContext context) {
-        node.getLeft().accept(this, context);
-        Expression left = (Expression) context.getNamedParseExpressions().pop();
-        node.getRight().accept(this, context);
-        Expression right = (Expression) context.getNamedParseExpressions().pop();
-        return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right));
+        return visitBinaryArithmetic(node,
+                (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context);
     }
 
     @Override
     public Expression visitXor(Xor node, CatalystPlanContext context) {
-        node.getLeft().accept(this, context);
-        Expression left = (Expression) context.getNamedParseExpressions().pop();
-        node.getRight().accept(this, context);
-        Expression right = (Expression) context.getNamedParseExpressions().pop();
-        return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right));
+        return visitBinaryArithmetic(node,
+                (left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context);
     }
 
     @Override
     public Expression visitNot(Not node, CatalystPlanContext context) {
         node.getExpression().accept(this, context);
-        Expression arg = (Expression) context.getNamedParseExpressions().pop();
-        return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(arg));
+        Optional<Expression> arg = context.popNamedParseExpressions();
+        return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null);
     }
 
     @Override
     public Expression visitSpan(Span node, CatalystPlanContext context) {
         node.getField().accept(this, context);
-        Expression field = (Expression) context.getNamedParseExpressions().pop();
+        Expression field = context.popNamedParseExpressions().get();
         node.getValue().accept(this, context);
-        Expression value = (Expression) context.getNamedParseExpressions().pop();
+        Expression value = context.popNamedParseExpressions().get();
         return context.getNamedParseExpressions().push(window(field, value, node.getUnit()));
     }
 
     @Override
     public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
         node.getField().accept(this, context);
-        Expression arg = (Expression) context.getNamedParseExpressions().pop();
+        Expression arg = context.popNamedParseExpressions().get();
         Expression aggregator = AggregatorTranslator.aggregator(node, arg);
         return context.getNamedParseExpressions().push(aggregator);
     }
@@ -296,16 +339,32 @@ public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanCon
     @Override
     public Expression visitCompare(Compare node, CatalystPlanContext context) {
         analyze(node.getLeft(), context);
-        Expression left = (Expression) context.getNamedParseExpressions().pop();
+        Optional<Expression> left = context.popNamedParseExpressions();
         analyze(node.getRight(), context);
-        Expression right = (Expression) context.getNamedParseExpressions().pop();
-        Predicate comparator = ComparatorTransformer.comparator(node, left, right);
-        return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator);
+        Optional<Expression> right = context.popNamedParseExpressions();
+        if (left.isPresent() && right.isPresent()) {
+            Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get());
+            return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator);
+        }
+        return null;
     }
 
     @Override
-    public Expression visitField(Field node, CatalystPlanContext context) {
-        return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getField().toString())));
+    public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) {
+        // when visiting inside a correlation, resolve the field against the relations in the traversal context
+        List<UnresolvedRelation> relation = findRelation(context.traversalContext());
+        if (!relation.isEmpty()) {
+            Optional<QualifiedName> resolveField = resolveField(relation, node);
+            return resolveField.map(qualifiedName -> context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts()))))
+                    .orElse(null);
+        }
+        return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts())));
+    }
+
+    @Override
+    public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) {
+        // And-reduce the mapping clause comparisons into a single condition expression
+        return node.getChild().stream().map(expression ->
+                visitCompare((Compare) expression, context)
+        ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null);
     }
 
     @Override
@@ -321,7 +380,7 @@ public Expression visitAllFields(AllFields node, CatalystPlanContext context) {
     @Override
     public Expression visitAlias(Alias node, CatalystPlanContext context) {
         node.getDelegated().accept(this, context);
-        Expression arg = context.getNamedParseExpressions().pop();
+        Expression arg = context.popNamedParseExpressions().get();
         return context.getNamedParseExpressions().push(
                 org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg,
                         node.getAlias() != null ? node.getAlias() : node.getName(),
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 1b26255f9..a810ea180 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -15,14 +15,18 @@
 import org.opensearch.sql.ast.expression.Alias;
 import org.opensearch.sql.ast.expression.DataType;
 import org.opensearch.sql.ast.expression.Field;
+import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Let;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Map;
 import org.opensearch.sql.ast.expression.ParseMethod;
 import org.opensearch.sql.ast.expression.QualifiedName;
+import org.opensearch.sql.ast.expression.Scope;
+import org.opensearch.sql.ast.expression.SpanUnit;
 import org.opensearch.sql.ast.expression.UnresolvedArgument;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 import org.opensearch.sql.ast.tree.Aggregation;
+import org.opensearch.sql.ast.tree.Correlation;
 import org.opensearch.sql.ast.tree.Dedupe;
 import org.opensearch.sql.ast.tree.Eval;
 import org.opensearch.sql.ast.tree.Filter;
@@ -41,9 +45,12 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
+import static java.util.Collections.emptyList;
+
 /** Class for building the AST. Refines the visit path and builds the AST nodes. */
 public class AstBuilder extends OpenSearchPPLParserBaseVisitor<UnresolvedPlan> {
 
@@ -99,6 +106,21 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext
         return new Filter(internalVisitExpression(ctx.logicalExpression()));
     }
 
+    @Override
+    public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) {
+        return new Correlation(ctx.correlationType().getText(),
+                ctx.fieldList().fieldExpression().stream()
+                        .map(this::internalVisitExpression)
+                        .collect(Collectors.toList()),
+                Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()),
+                        expressionBuilder.visit(ctx.scopeClause().value),
+                        SpanUnit.of(Objects.isNull(ctx.scopeClause().unit) ? "" : ctx.scopeClause().unit.getText())),
+                Objects.isNull(ctx.mappingList()) ? new FieldsMapping(emptyList()) : new FieldsMapping(ctx.mappingList()
+                        .mappingClause().stream()
+                        .map(this::internalVisitExpression)
+                        .collect(Collectors.toList())));
+    }
+
     /** Fields command. */
     @Override
     public UnresolvedPlan visitFieldsCommand(OpenSearchPPLParser.FieldsCommandContext ctx) {
@@ -149,7 +171,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
                             getTextInQuery(groupCtx),
                             internalVisitExpression(groupCtx)))
                     .collect(Collectors.toList()))
-            .orElse(Collections.emptyList());
+            .orElse(emptyList());
 
     UnresolvedExpression span =
         Optional.ofNullable(ctx.statsByClause())
@@ -160,7 +182,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
     Aggregation aggregation =
         new Aggregation(
             aggListBuilder.build(),
-            Collections.emptyList(),
+            emptyList(),
             groupList,
             span,
             ArgumentFactory.getArgumentList(ctx));
@@ -254,7 +276,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
   @Override
   public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
     List<UnresolvedExpression> groupList =
-        ctx.byClause() == null ? Collections.emptyList() : getGroupByList(ctx.byClause());
+        ctx.byClause() == null ? emptyList() : getGroupByList(ctx.byClause());
     return new RareTopN(
         RareTopN.CommandType.TOP,
         ArgumentFactory.getArgumentList(ctx),
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index e7d723afd..3344cd7c2 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -62,6 +62,16 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor<UnresolvedExpression> {
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java
new file mode 100644
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java
+public interface JoinSpecTransformer {
+
+    static LogicalPlan join(Correlation.CorrelationType correlationType, Seq<Expression> fields, Seq<Expression> mapping, LogicalPlan left, LogicalPlan right) {
+        // create a join statement - which replaces the different plan branches with a single plan containing the joined plans
+        switch (correlationType) {
+            case self:
+                // expecting exactly one source relation
+                if (!left.equals(right))
+                    throw new IllegalStateException("Correlation command with `self` type must have exactly one source table");
+                break;
+            case exact:
+                // expecting at least two source relations
+                if (left.equals(right))
+                    throw new IllegalStateException("Correlation command with `exact` type must have at least two different source tables");
+                break;
+            case approximate:
+                // expecting at least two source relations
+                if (left.equals(right))
+                    throw new IllegalStateException("Correlation command with `approximate` type must have at least two different source tables");
+                break;
+        }
+
+        if (fields.isEmpty())
+            throw new IllegalStateException("Correlation command was called with empty correlation fields");
+
+        if (mapping.isEmpty())
+            throw new IllegalStateException("Correlation command was called with empty correlation mappings");
+
+        if (mapping.seq().size() != fields.seq().size())
+            throw new IllegalStateException("Correlation command was called with a `fields` attribute whose number of elements differs from the `mapping` attribute");
+
+        // Define join condition
+        Expression joinCondition = buildJoinCondition(seqAsJavaListConverter(fields).asJava(), seqAsJavaListConverter(mapping).asJava(), correlationType);
+        // Define the join step that replaces the multiple query branches
+        return new Join(left, right, getType(correlationType), Option.apply(joinCondition), JoinHint.NONE());
+    }
+
+    static Expression buildJoinCondition(List<Expression> fields, List<Expression> mapping, Correlation.CorrelationType correlationType) {
+        switch (correlationType) {
+            case self:
+                // expecting exactly one source relation - mapping will be used to set the inner join counterpart
+                break;
+            case exact:
+                // expecting at least two source relations
+                return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null);
+            case approximate:
+                return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.Or::new).orElse(null);
+        }
+        return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null);
+    }
+
+    static JoinType getType(Correlation.CorrelationType correlationType) {
+        switch (correlationType) {
+            case self:
+            case exact:
+                return Inner$.MODULE$;
+            case approximate:
+                return FullOuter$.MODULE$;
+        }
+        return Inner$.MODULE$;
+    }
+}
\ No newline at end of file
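A quick usage sketch of the transformer above (the field and mapping sequences would normally come from the visitor's retained expressions; they are hand-built here for illustration):

```java
import static java.util.List.of;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.tree.Correlation;
import org.opensearch.sql.ppl.utils.JoinSpecTransformer;

public class JoinSpecSketch {
  public static void main(String[] args) {
    LogicalPlan left = new UnresolvedRelation(seq(of("alb_logs")), CaseInsensitiveStringMap.empty(), false);
    LogicalPlan right = new UnresolvedRelation(seq(of("traces")), CaseInsensitiveStringMap.empty(), false);
    Expression field = UnresolvedAttribute$.MODULE$.apply(seq(of("traceId")));
    // mapping(alb_logs.traceId = traces.traceId) -> the join condition
    Expression mapping = new EqualTo(
        UnresolvedAttribute$.MODULE$.apply(seq(of("alb_logs", "traceId"))),
        UnresolvedAttribute$.MODULE$.apply(seq(of("traces", "traceId"))));
    // `exact` correlation -> inner join with the And-reduced mapping as its condition
    LogicalPlan joined = JoinSpecTransformer.join(
        Correlation.CorrelationType.exact, seq(of(field)), seq(of(mapping)), left, right);
  }
}
```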
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java
new file mode 100644
index 000000000..b402aaae5
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java
@@ -0,0 +1,33 @@
+package org.opensearch.sql.ppl.utils;
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
+import org.opensearch.sql.ast.expression.QualifiedName;
+
+import java.util.List;
+import java.util.Optional;
+
+public interface RelationUtils {
+    /**
+     * attempts to resolve the field against the given relations:
+     * if the name doesn't contain a table prefix, the field is accepted as belonging to the current relation;
+     * if the name does contain a table prefix, the field is accepted only when its prefix matches a
+     * contextual relation's table name
+     *
+     * @param relations the relations in the current traversal context
+     * @param node the field's qualified name
+     * @return the resolved qualified name, or empty when it matches none of the relations
+     */
+    static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations, QualifiedName node) {
+        return relations.stream()
+                .map(rel -> {
+                    // if the name doesn't contain a table prefix - keep the field's name as is
+                    if (node.getPrefix().isEmpty())
+//                        return Optional.of(QualifiedName.of(relation.tableName(), node.getParts().toArray(new String[]{})));
+                        return Optional.of(node);
+                    // otherwise verify the field's table prefix corresponds to the current contextual relation
+                    if (node.getPrefix().get().toString().equals(rel.tableName()))
+                        return Optional.of(node);
+                    return Optional.empty();
+                }).filter(Optional::isPresent)
+                .map(field -> (QualifiedName) field.get())
+                .findFirst();
+    }
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala
new file mode 100644
index 000000000..450f21c63
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+trait FlintPPLSuite extends SharedSparkSession {
+  override protected def sparkConf = {
+    val conf = new SparkConf()
+      .set("spark.ui.enabled", "false")
+      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+      .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+      // Disable ConvertToLocalRelation for better test coverage.
Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. + .set("spark.sql.extensions", classOf[FlintPPLSparkExtensions].getName) + conf + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala deleted file mode 100644 index 8434c5bf1..000000000 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.spark.ppl - -import org.junit.Assert.assertEquals -import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} -import org.scalatest.matchers.should.Matchers - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, Like, Literal, SortOrder, UnixTimestamp} -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ - -class PPLLogicalAdvancedTranslatorTestSuite - extends SparkFunSuite - with LogicalPlanTestUtils - with Matchers { - - private val planTrnasformer = new CatalystQueryPlanVisitor() - private val pplParser = new PPLSyntaxParser() - - ignore("Find What are the average prices for different types of properties") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan(pplParser, "source = housing_properties | stats avg(price) by property_type", false), - context) - // SQL: SELECT property_type, AVG(price) FROM housing_properties GROUP BY property_type - val table = UnresolvedRelation(Seq("housing_properties")) - - val avgPrice = Alias(Average(UnresolvedAttribute("price")), "avg(price)")() - val propertyType = UnresolvedAttribute("property_type") - val grouped = Aggregate(Seq(propertyType), Seq(propertyType, avgPrice), table) - - val projectList = Seq( - UnresolvedAttribute("property_type"), - Alias(Average(UnresolvedAttribute("price")), "avg(price)")()) - val expectedPlan = Project(projectList, grouped) - - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore( - "Find the top 10 most expensive properties in California, including their addresses, prices, and cities") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = housing_properties | where state = `CA` | fields address, price, city | sort - price | head 10", - false), - context) - // SQL: SELECT address, price, city FROM housing_properties WHERE state = 'CA' ORDER BY price DESC LIMIT 10 - - // Constructing the expected Catalyst Logical Plan - val table = UnresolvedRelation(Seq("housing_properties")) - val filter = Filter(EqualTo(UnresolvedAttribute("state"), Literal("CA")), table) - val projectList = Seq( - UnresolvedAttribute("address"), - UnresolvedAttribute("price"), - UnresolvedAttribute("city")) - val projected = Project(projectList, filter) - val sortOrder = 
SortOrder(UnresolvedAttribute("price"), Descending) :: Nil - val sorted = Sort(sortOrder, true, projected) - val limited = Limit(Literal(10), sorted) - val finalProjectList = Seq( - UnresolvedAttribute("address"), - UnresolvedAttribute("price"), - UnresolvedAttribute("city")) - - val expectedPlan = Project(finalProjectList, limited) - - // Assert that the generated plan is as expected - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find the average price per unit of land space for properties in different cities") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = housing_properties | where land_space > 0 | eval price_per_land_unit = price / land_space | stats avg(price_per_land_unit) by city", - false), - context) - // SQL: SELECT city, AVG(price / land_space) AS avg_price_per_land_unit FROM housing_properties WHERE land_space > 0 GROUP BY city - val table = UnresolvedRelation(Seq("housing_properties")) - val filter = Filter(GreaterThan(UnresolvedAttribute("land_space"), Literal(0)), table) - val expression = AggregateExpression( - Average(Divide(UnresolvedAttribute("price"), UnresolvedAttribute("land_space"))), - mode = Complete, - isDistinct = false) - val aggregateExpr = Alias(expression, "avg_price_per_land_unit")() - val groupBy = Aggregate( - groupingExpressions = Seq(UnresolvedAttribute("city")), - aggregateExpressions = Seq(aggregateExpr), - filter) - - val expectedPlan = Project( - projectList = - Seq(UnresolvedAttribute("city"), UnresolvedAttribute("avg_price_per_land_unit")), - groupBy) - // Continue with your test... - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find the houses posted in the last month, how many are still for sale") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "search source=housing_properties | where listing_age >= 0 | where listing_age < 30 | stats count() by property_status", - false), - context) - // SQL: SELECT property_status, COUNT(*) FROM housing_properties WHERE listing_age >= 0 AND listing_age < 30 GROUP BY property_status; - - val filter = Filter( - LessThan(UnresolvedAttribute("listing_age"), Literal(30)), - Filter( - GreaterThanOrEqual(UnresolvedAttribute("listing_age"), Literal(0)), - UnresolvedRelation(Seq("housing_properties")))) - - val expression = AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) - - val aggregateExpressions = Seq(Alias(expression, "count")()) - - val groupByAttributes = Seq(UnresolvedAttribute("property_status")) - val expectedPlan = Aggregate(groupByAttributes, aggregateExpressions, filter) - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore( - "Find all the houses listed by agency Compass in decreasing price order. 
Also provide only price, address and agency name information.") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = housing_properties | where match( agency_name , `Compass` ) | fields address , agency_name , price | sort - price ", - false), - context) - // SQL: SELECT address, agency_name, price FROM housing_properties WHERE agency_name LIKE '%Compass%' ORDER BY price DESC - - val projectList = Seq( - UnresolvedAttribute("address"), - UnresolvedAttribute("agency_name"), - UnresolvedAttribute("price")) - val table = UnresolvedRelation(Seq("housing_properties")) - - val filterCondition = Like(UnresolvedAttribute("agency_name"), Literal("%Compass%"), '\\') - val filter = Filter(filterCondition, table) - - val sortOrder = Seq(SortOrder(UnresolvedAttribute("price"), Descending)) - val sort = Sort(sortOrder, true, filter) - - val expectedPlan = Project(projectList, sort) - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find details of properties owned by Zillow with at least 3 bedrooms and 2 bathrooms") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = housing_properties | where is_owned_by_zillow = 1 and bedroom_number >= 3 and bathroom_number >= 2 | fields address, price, city, listing_age", - false), - context) - // SQL:SELECT address, price, city, listing_age FROM housing_properties WHERE is_owned_by_zillow = 1 AND bedroom_number >= 3 AND bathroom_number >= 2; - val projectList = Seq( - UnresolvedAttribute("address"), - UnresolvedAttribute("price"), - UnresolvedAttribute("city"), - UnresolvedAttribute("listing_age")) - - val filterCondition = And( - And( - EqualTo(UnresolvedAttribute("is_owned_by_zillow"), Literal(1)), - GreaterThanOrEqual(UnresolvedAttribute("bedroom_number"), Literal(3))), - GreaterThanOrEqual(UnresolvedAttribute("bathroom_number"), Literal(2))) - - val expectedPlan = Project( - projectList, - Filter(filterCondition, UnresolvedRelation(TableIdentifier("housing_properties")))) - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find which cities in WA state have the largest number of houses for sale") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = housing_properties | where property_status = 'FOR_SALE' and state = 'WA' | stats count() as count by city | sort -count | head", - false), - context) - // SQL : SELECT city, COUNT(*) as count FROM housing_properties WHERE property_status = 'FOR_SALE' AND state = 'WA' GROUP BY city ORDER BY count DESC LIMIT 10; - val aggregateExpressions = Seq( - Alias( - AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), - "count")()) - val groupByAttributes = Seq(UnresolvedAttribute("city")) - - val filterCondition = And( - EqualTo(UnresolvedAttribute("property_status"), Literal("FOR_SALE")), - EqualTo(UnresolvedAttribute("state"), Literal("WA"))) - - val expectedPlan = Limit( - Literal(10), - Sort( - Seq(SortOrder(UnresolvedAttribute("count"), Descending)), - true, - Aggregate( - groupByAttributes, - aggregateExpressions, - Filter(filterCondition, UnresolvedRelation(TableIdentifier("housing_properties")))))) - - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find the top 5 referrers for the '/' path in apache access logs") { - val context = new 
CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan(pplParser, "source = access_logs | where path = `/` | top 5 referer", false), - context) - /* - SQL: SELECT referer, COUNT(*) as count - FROM access_logs - WHERE path = '/' GROUP BY referer ORDER BY count DESC LIMIT 5; - */ - val aggregateExpressions = Seq( - Alias( - AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), - "count")()) - val groupByAttributes = Seq(UnresolvedAttribute("referer")) - val filterCondition = EqualTo(UnresolvedAttribute("path"), Literal("/")) - val expectedPlan = Limit( - Literal(5), - Sort( - Seq(SortOrder(UnresolvedAttribute("count"), Descending)), - true, - Aggregate( - groupByAttributes, - aggregateExpressions, - Filter(filterCondition, UnresolvedRelation(TableIdentifier("access_logs")))))) - - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore( - "Find access paths by status code. How many error responses (status code 400 or higher) are there for each access path in the Apache access logs") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = access_logs | where status >= 400 | stats count() by path, status", - false), - context) - /* - SQL: SELECT path, status, COUNT(*) as count - FROM access_logs - WHERE status >=400 GROUP BY path, status; - */ - val aggregateExpressions = Seq( - Alias( - AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), - "count")()) - val groupByAttributes = Seq(UnresolvedAttribute("path"), UnresolvedAttribute("status")) - - val filterCondition = GreaterThanOrEqual(UnresolvedAttribute("status"), Literal(400)) - - val expectedPlan = Aggregate( - groupByAttributes, - aggregateExpressions, - Filter(filterCondition, UnresolvedRelation(TableIdentifier("access_logs")))) - - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore("Find max size of nginx access requests for every 15min") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = access_logs | stats max(size) by span( request_time , 15m) ", - false), - context) - - // SQL: SELECT MAX(size) AS max_size, floor(request_time / 900) AS time_span FROM access_logs GROUP BY time_span; - val aggregateExpressions = Seq(Alias( - AggregateExpression(Max(UnresolvedAttribute("size")), mode = Complete, isDistinct = false), - "max_size")()) - val groupByAttributes = - Seq(Alias(Floor(Divide(UnresolvedAttribute("request_time"), Literal(900))), "time_span")()) - - val expectedPlan = Aggregate( - groupByAttributes, - aggregateExpressions ++ groupByAttributes, - UnresolvedRelation(TableIdentifier("access_logs"))) - - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore("Find nginx logs with non 2xx status code and url containing 'products'") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = sso_logs-nginx-* | where match(http.url, 'products') and http.response.status_code >= `300`", - false), - context) - // SQL : SELECT * FROM `sso_logs-nginx-*` WHERE http.url LIKE '%products%' AND http.response.status_code >= 300; - val aggregateExpressions = Seq(Alias( - AggregateExpression(Max(UnresolvedAttribute("size")), mode = Complete, isDistinct = false), - "max_size")()) - val groupByAttributes = - 
Seq(Alias(Floor(Divide(UnresolvedAttribute("request_time"), Literal(900))), "time_span")()) - - val expectedPlan = Aggregate( - groupByAttributes, - aggregateExpressions, - UnresolvedRelation(TableIdentifier("access_logs"))) - - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore( - "Find What are the details (URL, response status code, timestamp, source address) of events in the nginx logs where the response status code is 400 or higher") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = sso_logs-nginx-* | where http.response.status_code >= `400` | fields http.url, http.response.status_code, @timestamp, communication.source.address", - false), - context) - // SQL : SELECT http.url, http.response.status_code, @timestamp, communication.source.address FROM sso_logs-nginx-* WHERE http.response.status_code >= 400; - val projectList = Seq( - UnresolvedAttribute("http.url"), - UnresolvedAttribute("http.response.status_code"), - UnresolvedAttribute("@timestamp"), - UnresolvedAttribute("communication.source.address")) - - val filterCondition = - GreaterThanOrEqual(UnresolvedAttribute("http.response.status_code"), Literal(400)) - - val expectedPlan = Project( - projectList, - Filter(filterCondition, UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")))) - - // Add to your unit test - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore( - "Find What are the average and max http response sizes, grouped by request method, for access events in the nginx logs") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = sso_logs-nginx-* | where event.name = `access` | stats avg(http.response.bytes), max(http.response.bytes) by http.request.method", - false), - context) - // SQL : SELECT AVG(http.response.bytes) AS avg_size, MAX(http.response.bytes) AS max_size, http.request.method FROM sso_logs-nginx-* WHERE event.name = 'access' GROUP BY http.request.method; - val aggregateExpressions = Seq( - Alias( - AggregateExpression( - Average(UnresolvedAttribute("http.response.bytes")), - mode = Complete, - isDistinct = false), - "avg_size")(), - Alias( - AggregateExpression( - Max(UnresolvedAttribute("http.response.bytes")), - mode = Complete, - isDistinct = false), - "max_size")()) - val groupByAttributes = Seq(UnresolvedAttribute("http.request.method")) - - val expectedPlan = Aggregate( - groupByAttributes, - aggregateExpressions ++ groupByAttributes, - Filter( - EqualTo(UnresolvedAttribute("event.name"), Literal("access")), - UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")))) - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - } - - ignore( - "Find flights from which carrier has the longest average delay for flights over 6k miles") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = opensearch_dashboards_sample_data_flights | where DistanceMiles > 6000 | stats avg(FlightDelayMin) by Carrier | sort -`avg(FlightDelayMin)` | head 1", - false), - context) - // SQL: SELECT AVG(FlightDelayMin) AS avg_delay, Carrier FROM opensearch_dashboards_sample_data_flights WHERE DistanceMiles > 6000 GROUP BY Carrier ORDER BY avg_delay DESC LIMIT 1; - val aggregateExpressions = Seq( - Alias( - AggregateExpression( - Average(UnresolvedAttribute("FlightDelayMin")), - mode = Complete, - isDistinct = false), - 
"avg_delay")()) - val groupByAttributes = Seq(UnresolvedAttribute("Carrier")) - - val expectedPlan = Limit( - Literal(1), - Sort( - Seq(SortOrder(UnresolvedAttribute("avg_delay"), Descending)), - true, - Aggregate( - groupByAttributes, - aggregateExpressions ++ groupByAttributes, - Filter( - GreaterThan(UnresolvedAttribute("DistanceMiles"), Literal(6000)), - UnresolvedRelation(TableIdentifier("opensearch_dashboards_sample_data_flights")))))) - - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } - - ignore("Find What's the average ram usage of windows machines over time aggregated by 1 week") { - val context = new CatalystPlanContext - val logPlan = planTrnasformer.visit( - plan( - pplParser, - "source = opensearch_dashboards_sample_data_logs | where match(machine.os, 'win') | stats avg(machine.ram) by span(timestamp,1w)", - false), - context) - // SQL : SELECT AVG(machine.ram) AS avg_ram, floor(extract(epoch from timestamp) / 604800) - // AS week_span FROM opensearch_dashboards_sample_data_logs WHERE machine.os LIKE '%win%' GROUP BY week_span - val aggregateExpressions = Seq( - Alias( - AggregateExpression( - Average(UnresolvedAttribute("machine.ram")), - mode = Complete, - isDistinct = false), - "avg_ram")()) - val groupByAttributes = Seq( - Alias( - Floor( - Divide( - UnixTimestamp(UnresolvedAttribute("timestamp"), Literal("yyyy-MM-dd HH:mm:ss")), - Literal(604800))), - "week_span")()) - - val expectedPlan = Aggregate( - groupByAttributes, - aggregateExpressions ++ groupByAttributes, - Filter( - Like(UnresolvedAttribute("machine.os"), Literal("%win%"), '\\'), - UnresolvedRelation(TableIdentifier("opensearch_dashboards_sample_data_logs")))) - - assertEquals(expectedPlan, context.getPlan) - assertEquals(logPlan, "???") - - } -} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 955aac3f5..87f7e5b28 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow} import org.apache.spark.sql.catalyst.plans.logical._ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite @@ -298,5 +298,77 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logPlan)) } + test("create ppl query count status amount by day window and group by status test") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = table | stats sum(status) by span(@timestamp, 1d) as status_count_by_day, status | head 100", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val 
status = Alias(UnresolvedAttribute("status"), "status")()
+    val statusAmount = UnresolvedAttribute("status")
+    val table = UnresolvedRelation(Seq("table"))
+
+    val windowExpression = Alias(
+      TimeWindow(
+        UnresolvedAttribute("`@timestamp`"),
+        TimeWindow.parseExpression(Literal("1 day")),
+        TimeWindow.parseExpression(Literal("1 day")),
+        0),
+      "status_count_by_day")()
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("SUM"), Seq(statusAmount), isDistinct = false),
+        "sum(status)")()
+    val aggregatePlan = Aggregate(
+      Seq(status, windowExpression),
+      Seq(aggregateExpressions, status, windowExpression),
+      table)
+    val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
+    val expectedPlan = Project(star, planWithLimit)
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test(
+    "create ppl query count only error (status >= 400) status amount by day window and group by status test") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(
+      plan(
+        pplParser,
+        "source = table | where status >= 400 | stats sum(status) by span(@timestamp, 1d) as status_count_by_day, status | head 100",
+        false),
+      context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val statusAlias = Alias(UnresolvedAttribute("status"), "status")()
+    val statusField = UnresolvedAttribute("status")
+    val table = UnresolvedRelation(Seq("table"))
+
+    val filterExpr = GreaterThanOrEqual(statusField, Literal(400))
+    val filterPlan = Filter(filterExpr, table)
+
+    val windowExpression = Alias(
+      TimeWindow(
+        UnresolvedAttribute("`@timestamp`"),
+        TimeWindow.parseExpression(Literal("1 day")),
+        TimeWindow.parseExpression(Literal("1 day")),
+        0),
+      "status_count_by_day")()
+
+    val aggregateExpressions =
+      Alias(UnresolvedFunction(Seq("SUM"), Seq(statusField), isDistinct = false), "sum(status)")()
+    val aggregatePlan = Aggregate(
+      Seq(statusAlias, windowExpression),
+      Seq(aggregateExpressions, statusAlias, windowExpression),
+      filterPlan)
+    val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
+    val expectedPlan = Project(star, planWithLimit)
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
 }
diff --git a/spark-sql-integration/README.md b/spark-sql-integration/README.md
new file mode 100644
index 000000000..07bf46406
--- /dev/null
+++ b/spark-sql-integration/README.md
@@ -0,0 +1,109 @@
+# Spark SQL Application
+
+This application executes a SQL query and stores the result in an OpenSearch index in the following format:
+```
+"stepId":"<emr-step-id>",
+"applicationId":"<spark-application-id>",
+"schema": "json blob",
+"result": "json blob"
+```
+
+## Prerequisites
+
++ Spark 3.3.1
++ Scala 2.12.15
++ flint-spark-integration
+
+## Usage
+
+To use this application, you can run Spark with the Flint extension:
+
+```
+./bin/spark-submit \
+    --class org.opensearch.sql.SQLJob \
+    --jars <flint-spark-integration-jar> \
+    sql-job.jar \
+    <spark-sql-query> \
+    <opensearch-index> \
+    <opensearch-host> \
+    <opensearch-port> \
+    <opensearch-scheme> \
+    <opensearch-auth> \
+    <opensearch-region> \
+```
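For orientation before the format details below: the job runs the query, serializes the rows and schema to JSON strings, and indexes a single result document. A hypothetical sketch of that flow (names and structure are illustrative only, not the actual `SQLJob` source):

```java
// Hypothetical sketch of the documented flow, not the actual SQLJob implementation:
// run the query, then collect rows and schema as JSON strings for indexing.
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SqlJobSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("sql-job-sketch").getOrCreate();
    Dataset<Row> result = spark.sql(args[0]); // the <spark-sql-query> argument
    List<String> rows = result.toJSON().collectAsList(); // becomes the "result" field
    String schema = result.schema().json(); // schema info backing the "schema" field
    // ...write {stepId, applicationId, schema, result} to the <opensearch-index>...
    spark.stop();
  }
}
```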
+## Result Specifications
+
+The following example shows how the result is written to the OpenSearch index after query execution.
+
+Let's assume the SQL query result is
+```
++------+------+
+|Letter|Number|
++------+------+
+|A     |1     |
+|B     |2     |
+|C     |3     |
++------+------+
+```
+The OpenSearch index document will then look like
+```json
+{
+  "_index" : ".query_execution_result",
+  "_id" : "A2WOsYgBMUoqCqlDJHrn",
+  "_score" : 1.0,
+  "_source" : {
+    "result" : [
+      "{'Letter':'A','Number':1}",
+      "{'Letter':'B','Number':2}",
+      "{'Letter':'C','Number':3}"
+    ],
+    "schema" : [
+      "{'column_name':'Letter','data_type':'string'}",
+      "{'column_name':'Number','data_type':'integer'}"
+    ],
+    "stepId" : "s-JZSB1139WIVU",
+    "applicationId" : "application_1687726870985_0003"
+  }
+}
+```
+
+## Build
+
+To build and run this application with Spark, you can run:
+
+```
+sbt clean sparkSqlApplicationCosmetic/publishM2
+```
+
+## Test
+
+To run tests, you can use:
+
+```
+sbt test
+```
+
+## Scalastyle
+
+To check code with scalastyle, you can run:
+
+```
+sbt scalastyle
+```
+
+## Code of Conduct
+
+This project has adopted an [Open Source Code of Conduct](../CODE_OF_CONDUCT.md).
+
+## Security
+
+If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
+
+## License
+
+See the [LICENSE](../LICENSE.txt) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
+
+## Copyright
+
+Copyright OpenSearch Contributors. See [NOTICE](../NOTICE) for details.
\ No newline at end of file