diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala new file mode 100644 index 000000000..407c2cb3b --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -0,0 +1,528 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLEvalITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test single eval expression with new field") { + val frame = sql(s""" + | source = $testTable | eval col = 1 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val evalProjectList = Seq(UnresolvedStar(None), Alias(Literal(1), "col")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval expressions with new fields") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = 2 | fields name, age + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(Literal(1), "col1")(), Alias(Literal(2), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eval expressions in fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = 2 | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, 1, 2), + Row("Hello", 30, 1, 2), + Row("John", 25, 1, 2), + Row("Jane", 20, 1, 2)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(Literal(1), "col1")(), Alias(Literal(2), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eval expression without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = "New Field1", col2 = "New Field2" + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "New Field1", "New Field2"), + Row("Hello", 30, "New York", "USA", 2023, 4, "New Field1", "New Field2"), + Row("John", 25, "Ontario", "Canada", 2023, 4, "New Field1", "New Field2"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "New Field1", "New Field2")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val projectList = Seq( + UnresolvedStar(None), + Alias(Literal("New Field1"), "col1")(), + Alias(Literal("New Field2"), "col2")()) + val expectedPlan = Project(seq(UnresolvedStar(None)), Project(projectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test reusing existing fields in eval expressions") { + val frame = sql(s""" + | source = $testTable | eval col1 = state, col2 = country | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA"), + Row("Hello", 30, "New York", "USA"), + Row("John", 25, "Ontario", "Canada"), + Row("Jane", 20, "Quebec", "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("state"), "col1")(), + Alias(UnresolvedAttribute("country"), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test( + "test overriding existing fields: throw exception when specify the new field in fields command") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = 40 | eval name = upper(name) | sort name | fields name, age, state + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'name' is ambiguous")) + } + + test("test overriding existing fields: throw exception when specify the new field in where") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = abs(age) | where age < 50 + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + } + + test( + "test overriding existing fields: throw exception when specify the new field in aggregate expression") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = abs(age) | stats avg(age) + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + } + + test( + "test overriding existing fields: throw exception when specify the new field in grouping list") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval country = upper(country) | stats avg(age) by country + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'country' is ambiguous")) + } + + test("test override existing fields: the eval field doesn't appear in fields command") { + val frame = sql(s""" + | source = $testTable | eval age = 40, name = upper(name) | sort name | fields state, country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("New York", "USA"), + Row("California", "USA"), + Row("Quebec", "Canada"), + Row("Ontario", "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val projectList = Seq( + UnresolvedStar(None), + Alias(Literal(40), "age")(), + Alias( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + "name")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + Project(projectList, table)) + val expectedPlan = + Project(seq(UnresolvedAttribute("state"), UnresolvedAttribute("country")), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test override existing fields: the new fields not appear in fields command") { + val frame = sql(s""" + | source = $testTable | eval age = 40 | eval name = upper(name) | sort name + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + // In Spark, `name` in eval (as an alias) will be treated as a new column (exprIds are different). + // So if `name` appears in fields command, it will throw ambiguous reference exception. + val expectedResults: Array[Row] = Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 40, "HELLO"), + Row("Jake", 70, "California", "USA", 2023, 4, 40, "JAKE"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 40, "JANE"), + Row("John", 25, "Ontario", "Canada", 2023, 4, 40, "JOHN")) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(40), "age")()) + val evalProjectList2 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + "name")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + Project(evalProjectList2, Project(evalProjectList1, table))) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval commands in fields list") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | eval col2 = 2 | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, 1, 2), + Row("Hello", 30, 1, 2), + Row("John", 25, 1, 2), + Row("Jane", 20, 1, 2)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(1), "col1")()) + val evalProjectList2 = Seq(UnresolvedStar(None), Alias(Literal(2), "col2")()) + val expectedPlan = + Project(fieldsProjectList, Project(evalProjectList2, Project(evalProjectList1, table))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval commands without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = ln(age) | eval col2 = unix_timestamp('2020-09-16 17:30:00') | sort - col1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, 4.248495242049359, 1600302600), + Row("Hello", 30, "New York", "USA", 2023, 4, 3.4011973816621555, 1600302600), + Row("John", 25, "Ontario", "Canada", 2023, 4, 3.2188758248682006, 1600302600), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 2.995732273553991, 1600302600)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val evalProjectList1 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), + "col1")()) + val evalProjectList2 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "unix_timestamp", + seq(Literal("2020-09-16 17:30:00")), + isDistinct = false), + "col2")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("col1"), Descending)), + global = true, + Project(evalProjectList2, Project(evalProjectList1, table))) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex eval commands - case 1") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort age | head 2 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row("Jane", 20, 2), Row("John", 25, 2)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 2") { + val frame = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | sort + col2 | head 2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 3") { + val frame = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | fields name, age | sort + col2 | head 2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Hello", 30)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 4: execute 1, 2 and 3 together") { + val frame1 = sql(s""" + | source = $testTable | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort age | head 2 | fields name, age, col2 + | """.stripMargin) + val results1: Array[Row] = frame1.collect() + // results1.foreach(println(_)) + val expectedResults1: Array[Row] = Array(Row("Jane", 20, 2), Row("John", 25, 2)) + assert(results1.sameElements(expectedResults1)) + + val frame2 = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | sort + col2 | head 2 + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + // results2.foreach(println(_)) + val expectedResults2: Array[Row] = Array( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30)) + assert(results2.sameElements(expectedResults2)) + + val frame3 = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | fields name, age | sort + col2 | head 2 + | """.stripMargin) + val results3: Array[Row] = frame3.collect() + // results3.foreach(println(_)) + val expectedResults3: Array[Row] = Array(Row("John", 25), Row("Hello", 30)) + assert(results3.sameElements(expectedResults3)) + } + + test("test eval expression used in aggregation") { + val frame = sql(s""" + | source = $testTable | eval col1 = age, col2 = country | stats avg(col1) by col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("age"), "col1")(), + Alias(UnresolvedAttribute("country"), "col2")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("col1")), isDistinct = false), + "avg(col1)")(), + Alias(UnresolvedAttribute("col2"), "col2")()) + val aggregatePlan = Aggregate( + Seq(Alias(UnresolvedAttribute("col2"), "col2")()), + aggregateExpressions, + Project(evalProjectList, table)) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex eval expressions with fields command") { + val frame = sql(s""" + | source = $testTable | eval new_name = upper(name) | eval compound_field = concat('Hello ', if(like(new_name, 'HEL%'), 'World', name)) | fields new_name, compound_field + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("JAKE", "Hello Jake"), + Row("HELLO", "Hello World"), + Row("JOHN", "Hello John"), + Row("JANE", "Hello Jane")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test complex eval expressions without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = "New Field" | eval col2 = upper(lower(col1)) + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "New Field", "NEW FIELD"), + Row("Hello", 30, "New York", "USA", 2023, 4, "New Field", "NEW FIELD"), + Row("John", 25, "Ontario", "Canada", 2023, 4, "New Field", "NEW FIELD"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "New Field", "NEW FIELD")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test depended eval expressions in individual eval command") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | eval col2 = col1 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70, 1), Row("Hello", 30, 1), Row("John", 25, 1), Row("Jane", 20, 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"), UnresolvedAttribute("col2")) + val evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(1), "col1")()) + val evalProjectList2 = Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("col1"), "col2")()) + val expectedPlan = + Project(fieldsProjectList, Project(evalProjectList2, Project(evalProjectList1, table))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // +--------------------------------+ + // | Below tests are not supported | + // +--------------------------------+ + // Todo: Upgrading spark version to 3.4.0 and above could fix this test. + // https://issues.apache.org/jira/browse/SPARK-27561 + ignore("test lateral eval expressions references - SPARK-27561 required") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = col1 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70, 1), Row("Hello", 30, 1), Row("John", 25, 1), Row("Jane", 20, 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"), UnresolvedAttribute("col2")) + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(Literal(1), "col1")(), + Alias(UnresolvedAttribute("col1"), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // Todo excluded fields not support yet + ignore("test single eval expression with excluded fields") { + val frame = sql(s""" + | source = $testTable | eval new_field = "New Field" | fields - age + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", "California", "USA", 2023, 4, "New Field"), + Row("Hello", "New York", "USA", 2023, 4, "New Field"), + Row("John", "Ontario", "Canada", 2023, 4, "New Field"), + Row("Jane", "Quebec", "Canada", 2023, 4, "New Field")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 61ef5b670..1538f43be 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -240,6 +240,23 @@ The next samples of PPL queries are currently supported: - `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` - `source = table | where c = 'test' NOT a > 1 | fields a,b,c` + +**Eval** + +Assumptions: `a`, `b`, `c` are existing fields in `table` + - `source = table | eval f = 1 | fields a,b,c,f` + - `source = table | eval f = 1` (output a,b,c,f fields) + - `source = table | eval n = now() | eval t = unix_timestamp(a) | fields n,t` + - `source = table | eval f = a | where f > 1 | sort f | fields a,b,c | head 5` + - `source = table | eval f = a * 2 | eval h = f * 2 | fields a,f,h` + - `source = table | eval f = a * 2, h = f * 2 | fields a,f,h` (Spark 3.4.0+ required) + - `source = table | eval f = a * 2, h = b | stats avg(f) by h` + +Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous" + - `source = table | eval a = 10 | fields a,b,c` + - `source = table | eval a = a * 2 | stats avg(a)` + - `source = table | eval a = abs(a) | where a > 0` + **Aggregations** - `source = table | stats avg(a) ` - `source = table | where a < 50 | stats avg(c) ` @@ -261,6 +278,7 @@ The next samples of PPL queries are currently supported: - `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst) - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) + - `eval` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/eval.rst) - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, DISTINCT_COUNT, MAX, MIN and SUM aggregation functions) - `sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index aac3c3f36..2d0986890 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -36,6 +36,7 @@ commands | statsCommand | sortCommand | headCommand + | evalCommand ; searchCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java index 3f51b595e..35ded8d8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java @@ -13,8 +13,7 @@ /** Argument. */ public class Argument extends UnresolvedExpression { private final String name; - private String argName; - private Literal value; + private final Literal value; public Argument(String name, Literal value) { this.name = name; @@ -27,8 +26,8 @@ public List getChild() { return Arrays.asList(value); } - public String getArgName() { - return argName; + public String getName() { + return name; } public Literal getValue() { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java index 39b42dfe4..a8ec28d0e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java @@ -8,25 +8,25 @@ import com.google.common.collect.ImmutableList; import org.opensearch.sql.ast.AbstractNodeVisitor; -import java.util.ArrayList; import java.util.Collections; import java.util.List; + public class Field extends UnresolvedExpression { - private final UnresolvedExpression field; + private final QualifiedName field; private final List fieldArgs; /** Constructor of Field. */ - public Field(UnresolvedExpression field) { + public Field(QualifiedName field) { this(field, Collections.emptyList()); } /** Constructor of Field. */ - public Field(UnresolvedExpression field, List fieldArgs) { + public Field(QualifiedName field, List fieldArgs) { this.field = field; this.fieldArgs = fieldArgs; } - public UnresolvedExpression getField() { + public QualifiedName getField() { return field; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java index 6cc2b66ff..0a49bbb6c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java @@ -4,21 +4,20 @@ import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.FieldsMapping; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Scope; -import org.opensearch.sql.ast.expression.SpanUnit; -import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; /** Logical plan node of correlation , the interface for building the searching sources. */ public class Correlation extends UnresolvedPlan { - private final CorrelationType correlationType; - private final List fieldsList; + private final CorrelationType correlationType; + private final List fieldsList; private final Scope scope; private final FieldsMapping mappingListContext; private UnresolvedPlan child ; - public Correlation(String correlationType, List fieldsList, Scope scope, FieldsMapping mappingListContext) { + public Correlation(String correlationType, List fieldsList, Scope scope, FieldsMapping mappingListContext) { this.correlationType = CorrelationType.valueOf(correlationType); this.fieldsList = fieldsList; this.scope = scope; @@ -45,7 +44,7 @@ public CorrelationType getCorrelationType() { return correlationType; } - public List getFieldsList() { + public List getFieldsList() { return fieldsList; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 04f4320c1..fd8d81e5c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -31,6 +31,7 @@ import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; @@ -60,6 +61,7 @@ import scala.Option; import scala.collection.Seq; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -224,7 +226,22 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); + LogicalPlan child = node.getChild().get(0).accept(this, context); + List aliases = new ArrayList<>(); + List letExpressions = node.getExpressionList(); + for(Let let : letExpressions) { + Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); + aliases.add(alias); + } + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + List expressionList = visitExpressionList(aliases, context); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + return child; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ce9eea769..9973f4676 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -110,7 +110,9 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { return new Correlation(ctx.correlationType().getText(), ctx.fieldList().fieldExpression().stream() + .map(OpenSearchPPLParser.FieldExpressionContext::qualifiedName) .map(this::internalVisitExpression) + .map(u -> (QualifiedName) u) .collect(Collectors.toList()), Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()), expressionBuilder.visit(ctx.scopeClause().value), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 92e9dd458..71abb329f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -141,7 +141,7 @@ public UnresolvedExpression visitWcFieldExpression(OpenSearchPPLParser.WcFieldEx @Override public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ctx) { - return new Field( + return new Field((QualifiedName) visit(ctx.sortFieldExpression().fieldExpression().qualifiedName()), ArgumentFactory.getArgumentList(ctx)); } diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java index 4a942d067..5a20de2d4 100644 --- a/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java +++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.common.utils; import static org.junit.Assert.assertEquals; diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala new file mode 100644 index 000000000..772eb050a --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -0,0 +1,217 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanEvalTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test eval expressions not included in fields expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 1 | fields c", false), + context) + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions included in fields expression") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, c = 1 | fields a, b, c", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "c")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions without fields command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 1", false), context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = + Project(seq(UnresolvedStar(None)), Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 1 | sort - a | fields b", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val sortOrder = SortOrder(UnresolvedAttribute("a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions with multiple recursive sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, a = a | sort - a | fields b", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(UnresolvedAttribute("a"), "a")()) + val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val sortOrder = SortOrder(UnresolvedAttribute("a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple eval expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = 1, b = 'hello' | eval b = a | sort - b | fields b", + false), + context) + + val evalProjectList1: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal("hello"), "b")()) + val evalProjectList2: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "b")(exprId = ExprId(2), qualifier = Seq.empty)) + val evalProject1 = Project(evalProjectList1, UnresolvedRelation(Seq("t"))) + val evalProject2 = Project(evalProjectList2, evalProject1) + val sortOrder = SortOrder(UnresolvedAttribute("b"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject2) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - date function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = TIMESTAMP('2020-09-16 17:30:00') | fields a", false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("timestamp", seq(Literal("2020-09-16 17:30:00")), isDistinct = false), + "a")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - math function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = RAND() | fields a", false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")( + exprId = ExprId(0), + qualifier = Seq.empty)) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - compound function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = if(like(b, '%Hello%'), 'World', 'Hi') | fields a", + false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "if", + seq( + UnresolvedFunction( + "like", + seq(UnresolvedAttribute("b"), Literal("%Hello%")), + isDistinct = false), + Literal("World"), + Literal("Hi")), + isDistinct = false), + "a")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // Todo fields-excluded command not supported + ignore("test eval expressions with fields-excluded command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 2 | fields - b", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // Todo fields-included command not supported + ignore("test eval expressions with fields-included command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 2 | fields + b", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}