From 8843a872dc66878ac6155d4d0fb6cd9a76dd3219 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 1 Aug 2024 22:59:21 +0800 Subject: [PATCH 1/2] Support more builtin functions by adding a name mapping Signed-off-by: Lantao Jin --- .../flint/spark/FlintSparkSuite.scala | 22 ++ .../FlintSparkPPLBuiltinFunctionITSuite.scala | 290 +++++++++++------- .../ppl/utils/BuiltinFunctionTranslator.java | 32 ++ ...ggregationQueriesTranslatorTestSuite.scala | 31 +- ...lPlanBasicQueriesTranslatorTestSuite.scala | 27 +- ...ogicalPlanFiltersTranslatorTestSuite.scala | 29 +- ...PlanMathFunctionsTranslatorTestSuite.scala | 65 +++- ...anStringFunctionsTranslatorTestSuite.scala | 29 +- ...PlanTimeFunctionsTranslatorTestSuite.scala | 177 ++++++++++- 9 files changed, 517 insertions(+), 185 deletions(-) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index fbb2f89bd..7e0b68376 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -188,6 +188,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createNullableStateCountryTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada'), + | (null, 10, null, 'Canada') + | """.stripMargin) + } + protected def createOccupationTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala 
b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala index 127b29295..c9bf8a926 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -5,9 +5,11 @@ package org.opensearch.flint.spark.ppl +import java.sql.{Date, Time, Timestamp} + import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -22,12 +24,14 @@ class FlintSparkPPLBuiltinFunctionITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" + private val testNullTable = "spark_catalog.default.flint_ppl_test_null" override def beforeAll(): Unit = { super.beforeAll() // Create test table createPartitionedStateCountryTable(testTable) + createNullableStateCountryTable(testNullTable) } protected override def afterEach(): Unit = { @@ -44,17 +48,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable name=concat('He', 'llo') | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", 
"default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("name"), @@ -62,7 +61,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -71,15 +69,11 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable name=concat('Hello', state) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array.empty assert(results.sameElements(expectedResults)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("name"), @@ -90,7 +84,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -99,17 +92,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where length(name) = 5 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -117,7 +105,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -126,17 +113,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where leNgTh(name) = 5 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -144,7 +126,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -153,17 +134,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where lower(name) = "hello" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: 
Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("lower", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -171,7 +147,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -180,17 +155,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where upper(name) = upper("hello") | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -198,7 +168,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -207,17 +176,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable 
|where substring(name, 2, 2) = "el" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -228,7 +192,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -237,17 +200,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable | where like(name, '_ello%') | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val likeFunction = UnresolvedFunction( "like", @@ -257,7 +215,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(likeFunction, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans 
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -266,17 +223,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where replace(name, 'o', ' ') = "Hell " | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -287,7 +239,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -296,17 +247,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where trim(replace(name, 'o', ' ')) = "Hell" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -321,7 +267,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = 
Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -330,17 +275,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = abs(-30) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -348,7 +288,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -357,17 +296,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where abs(age) = 30 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", 
"flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("abs", seq(UnresolvedAttribute("age")), isDistinct = false), @@ -375,7 +309,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -384,17 +317,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = ceil(29.7) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -402,7 +330,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -411,17 +338,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = floor(30.4) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) 
assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -429,7 +351,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -438,17 +359,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where ln(age) > 4 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = GreaterThan( UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), @@ -456,7 +372,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -465,17 +380,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where mod(age, 10) = 0 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // 
Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("Jane", 20)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("mod", seq(UnresolvedAttribute("age"), Literal(10)), isDistinct = false), @@ -483,7 +393,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -492,17 +401,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where sqrt(pow(age, 2)) = 30.0 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -517,7 +421,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, 
checkAnalysis = false) } @@ -526,18 +429,13 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = GreaterThan( UnresolvedFunction( @@ -548,7 +446,183 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test arithmetic operators (+ - * / %)") { + val frame = sql(s""" + | source = $testTable | where (sqrt(pow(age, 2)) + sqrt(pow(age, 2)) / 1 - sqrt(pow(age, 2)) * 1) % 25.0 = 0 | fields name, age + | """.stripMargin) // equals age + age - age + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("John", 25)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test boolean operators (= != < <= > >=)") { + val frame = sql(s""" + | source = $testTable | eval a = age = 30, b = age != 70, c = 30 < age, d = 30 <= age, e = 30 > age, f = 30 >= age | fields age, a, b, c, d, e, f + | """.stripMargin) // equals age + age - age + + val results: Array[Row] = 
frame.collect() + val expectedResults: Array[Row] = Array( + Row(70, false, false, true, true, false, false), + Row(30, true, true, false, true, false, true), + Row(25, false, true, false, false, true, true), + Row(20, false, true, false, false, true, true)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test boolean condition functions - isnull isnotnull ifnull nullif") { + val frameIsNull = sql(s""" + | source = $testNullTable | where isnull(name) | fields age + | """.stripMargin) + + val results1: Array[Row] = frameIsNull.collect() + val expectedResults1: Array[Row] = Array(Row(10)) + assert(results1.sameElements(expectedResults1)) + + val frameIsNotNull = sql(s""" + | source = $testNullTable | where isnotnull(name) | fields name + | """.stripMargin) + + val results2: Array[Row] = frameIsNotNull.collect() + val expectedResults2: Array[Row] = Array(Row("John"), Row("Jane"), Row("Jake"), Row("Hello")) + assert(results2.sameElements(expectedResults2)) + + val frameIfNull = sql(s""" + | source = $testNullTable | eval new_name = ifnull(name, "Unknown") | fields new_name, age + | """.stripMargin) + + val results3: Array[Row] = frameIfNull.collect() + val expectedResults3: Array[Row] = Array( + Row("John", 25), + Row("Jane", 20), + Row("Unknown", 10), + Row("Jake", 70), + Row("Hello", 30)) + assert(results3.sameElements(expectedResults3)) + + val frameNullIf = sql(s""" + | source = $testNullTable | eval new_age = nullif(age, 20) | fields name, new_age + | """.stripMargin) + + val results4: Array[Row] = frameNullIf.collect() + val expectedResults4: Array[Row] = + Array(Row("John", 25), Row("Jane", null), Row(null, 10), Row("Jake", 70), Row("Hello", 30)) + assert(results4.sameElements(expectedResults4)) + } + + test("test typeof function") { + val frame = sql(s""" + | source = $testNullTable | eval tdate = typeof(DATE('2008-04-14')), tint = typeof(1), 
tnow = typeof(now()), tcol = typeof(age) | fields tdate, tint, tnow, tcol | head 1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("date", "int", "timestamp", "int")) + assert(results.sameElements(expectedResults)) + } + + test("test the builtin functions which required additional name mapping") { + val frame = sql(s""" + | source = $testNullTable + | | eval a = DAY_OF_WEEK(DATE('2020-08-26')) + | | eval b = DAY_OF_MONTH(DATE('2020-08-26')) + | | eval c = DAY_OF_YEAR(DATE('2020-08-26')) + | | eval d = WEEK_OF_YEAR(DATE('2020-08-26')) + | | eval e = WEEK(DATE('2020-08-26')) + | | eval f = MONTH_OF_YEAR(DATE('2020-08-26')) + | | eval g = HOUR_OF_DAY(DATE('2020-08-26')) + | | eval h = MINUTE_OF_HOUR(DATE('2020-08-26')) + | | eval i = SECOND_OF_MINUTE(DATE('2020-08-26')) + | | eval j = SUBDATE(DATE('2020-08-26'), 1) + | | eval k = ADDDATE(DATE('2020-08-26'), 1) + | | eval l = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')) + | | eval m = DATEDIFF(ADDDATE(LOCALTIME(), 1), LOCALTIME()) + | | fields a, b, c, d, e, f, g, h, i, j, k, l, m + | | head 1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = { + Array( + Row( + 4, + 26, + 239, + 35, + 35, + 8, + 0, + 0, + 0, + Date.valueOf("2020-08-25"), + Date.valueOf("2020-08-27"), + 1, + 1)) + } + assert(results.sameElements(expectedResults)) + } + + test("not all arguments could work in builtin functions") { + intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = WEEK(DATE('2008-02-20'), 1) + | """.stripMargin)) + intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = SUBDATE(DATE('2020-08-26'), INTERVAL 31 DAY) + | """.stripMargin)) + intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = ADDDATE(DATE('2020-08-26'), INTERVAL 1 HOUR) + | """.stripMargin)) + } + + // Todo + // +---------------------------------------+ + // | Below 
tests are not supported (cast) | + // +---------------------------------------+ + ignore("test cast to string") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(true as string), cint = CAST(1 as string), cdate = CAST(CAST('2012-08-07' as date) as string) | fields cbool, cint, cdate + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(true, 1, "2012-08-07")) + assert(results.sameElements(expectedResults)) + } + + ignore("test cast to number") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(true as int), cstring = CAST('1' as int) | fields cbool, cstring + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1, 1)) + assert(results.sameElements(expectedResults)) + } + + ignore("test cast to date") { + val frame = sql(s""" + | source = $testNullTable | eval cdate = CAST('2012-08-07' as date), ctime = CAST('01:01:01' as time), ctimestamp = CAST('2012-08-07 01:01:01' as timestamp) | fields cdate, ctime, ctimestamp + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row( + Date.valueOf("2012-08-07"), + Time.valueOf("01:01:01"), + Timestamp.valueOf("2012-08-07 01:01:01"))) + assert(results.sameElements(expectedResults)) + } + + ignore("test can be chained") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(CAST(true as string) as boolean) | fields cbool + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(true)) + assert(results.sameElements(expectedResults)) + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index 0d57fea20..6f17b1247 100644 --- 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -5,18 +5,49 @@ package org.opensearch.sql.ppl.utils; +import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.Expression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; public interface BuiltinFunctionTranslator { + /** + * The name mapping between PPL builtin functions to Spark builtin functions. + */ + static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING = new ImmutableMap.Builder() + // arithmetic operators + .put(BuiltinFunctionName.ADD.name().toLowerCase(Locale.ROOT), "+") + .put(BuiltinFunctionName.SUBTRACT.name().toLowerCase(Locale.ROOT), "-") + .put(BuiltinFunctionName.MULTIPLY.name().toLowerCase(Locale.ROOT), "*") + .put(BuiltinFunctionName.DIVIDE.name().toLowerCase(Locale.ROOT), "/") + .put(BuiltinFunctionName.MODULUS.name().toLowerCase(Locale.ROOT), "%") + // time functions + .put(BuiltinFunctionName.DAY_OF_WEEK.name().toLowerCase(Locale.ROOT), "dayofweek") + .put(BuiltinFunctionName.DAY_OF_MONTH.name().toLowerCase(Locale.ROOT), "dayofmonth") + .put(BuiltinFunctionName.DAY_OF_YEAR.name().toLowerCase(Locale.ROOT), "dayofyear") + .put(BuiltinFunctionName.WEEK_OF_YEAR.name().toLowerCase(Locale.ROOT), "weekofyear") + .put(BuiltinFunctionName.WEEK.name().toLowerCase(Locale.ROOT), "weekofyear") + .put(BuiltinFunctionName.MONTH_OF_YEAR.name().toLowerCase(Locale.ROOT), "month") + .put(BuiltinFunctionName.HOUR_OF_DAY.name().toLowerCase(Locale.ROOT), "hour") + .put(BuiltinFunctionName.MINUTE_OF_HOUR.name().toLowerCase(Locale.ROOT), "minute") + 
.put(BuiltinFunctionName.SECOND_OF_MINUTE.name().toLowerCase(Locale.ROOT), "second") + .put(BuiltinFunctionName.SUBDATE.name().toLowerCase(Locale.ROOT), "date_sub") // only maps subdate(date, days) + .put(BuiltinFunctionName.ADDDATE.name().toLowerCase(Locale.ROOT), "date_add") // only maps adddate(date, days) + .put(BuiltinFunctionName.DATEDIFF.name().toLowerCase(Locale.ROOT), "datediff") + .put(BuiltinFunctionName.LOCALTIME.name().toLowerCase(Locale.ROOT), "localtimestamp") + // condition functions + .put(BuiltinFunctionName.IS_NULL.name().toLowerCase(Locale.ROOT), "isnull") + .put(BuiltinFunctionName.IS_NOT_NULL.name().toLowerCase(Locale.ROOT), "isnotnull") + .build(); + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { // TODO change it when UDF is supported @@ -24,6 +55,7 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); + name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.getOrDefault(name, name); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 1fdd20c74..ba634cc1c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import 
org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers @@ -13,17 +12,19 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - test("test average price ") { + test("test average price") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = @@ -38,10 +39,10 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } - ignore("test average price with Alias") { + test("test average price with Alias") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -57,7 +58,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, 
logPlan, false) } test("test average price group by product ") { @@ -81,7 +82,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("test average price group by product and filter") { @@ -109,7 +110,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("test average price group by product and filter sorted") { @@ -144,7 +145,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite global = true, aggregatePlan) val expectedPlan = Project(star, sortedPlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl simple avg age by span of interval of 10 years query test ") { val context = new CatalystPlanContext @@ -164,7 +165,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) val expectedPlan = Project(star, aggregatePlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl simple avg age by span of interval of 10 years query with sort test ") { @@ -190,7 +191,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) val expectedPlan = Project(star, sortedPlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl 
simple avg age by span of interval of 10 years by country query test ") { @@ -219,7 +220,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite tableRelation) val expectedPlan = Project(star, aggregatePlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count sales by weeks window and productId with sorting test") { val context = new CatalystPlanContext @@ -257,7 +258,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val expectedPlan = Project(star, sortedPlan) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count sales by days window and productId with sorting test") { @@ -296,7 +297,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite aggregatePlan) val expectedPlan = Project(star, sortedPlan) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count status amount by day window and group by status test") { val context = new CatalystPlanContext @@ -331,7 +332,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) val expectedPlan = Project(star, planWithLimit) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( "create ppl query count only error (status >= 400) status amount by day window and group by status test") { @@ -368,7 +369,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) val expectedPlan = Project(star, planWithLimit) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + 
comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index bc31691d0..5b94ca092 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers @@ -13,11 +12,12 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.IntegerType class PPLLogicalPlanBasicQueriesTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -31,7 +31,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with escaped table name") { @@ -41,7 +41,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = 
Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with schema.table and no explicit fields (defaults to all fields)") { @@ -51,7 +51,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } @@ -62,7 +62,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with one field projected") { @@ -72,7 +72,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two fields projected") { @@ -82,7 +82,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val table = UnresolvedRelation(Seq("t")) val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) val expectedPlan = Project(projectList, table) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with one table with two fields projected sorted by one field") { @@ -97,7 +97,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val sorted = Sort(sortOrder, true, table) val expectedPlan = Project(projectList, sorted) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( @@ -111,7 
+111,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table))) val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -129,8 +129,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB)) val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( @@ -152,7 +151,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("Search multiple tables - translated into union call with fields") { @@ -172,6 +171,6 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala index 27dd972fc..fd7957106 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.apache.hadoop.conf.Configuration import org.junit.Assert.assertEquals import org.mockito.Mockito.when import 
org.opensearch.flint.spark.ppl.PlaneUtils.plan @@ -20,11 +19,13 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, Or, SortOrder, UnixTimestamp} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class PPLLogicalPlanFiltersTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -40,7 +41,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 'and' filtered ") { @@ -54,7 +55,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(And(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 'or' filtered ") { @@ -68,7 +69,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(Or(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 'not' filtered ") { @@ -82,7 +83,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite 
val filterPlan = Filter(Or(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -96,7 +97,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -111,7 +112,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -127,7 +128,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -141,7 +142,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -155,7 +156,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -169,7 +170,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -183,7 +184,7 @@ class 
PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -197,7 +198,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -218,6 +219,6 @@ class PPLLogicalPlanFiltersTranslatorTestSuite Project(projectList, filterPlan)) val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala index 24336b098..73fa2a999 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -5,20 +5,20 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import 
org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanMathFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -36,7 +36,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ceil") { @@ -50,7 +50,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test floor") { @@ -64,7 +64,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ln") { @@ -78,7 +78,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test mod") { @@ -93,7 +93,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, 
logPlan, false) } test("test pow") { @@ -107,7 +107,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test sqrt") { @@ -121,7 +121,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test arithmetic: + - * / %") { @@ -145,19 +145,52 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite isDistinct = false)), isDistinct = false) // sqrt(pow(a, 2)) / 1 - val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false) + val sqrtPowDivide = UnresolvedFunction("/", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) * 1 val sqrtPowMultiply = - UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false) + UnresolvedFunction("*", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) % 1 - val sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false) + val sqrtPowMod = UnresolvedFunction("%", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false) - val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false) + val add = UnresolvedFunction("+", seq(sqrtPow, sqrtPowDivide), isDistinct = false) + val sub = UnresolvedFunction("-", seq(add, sqrtPowMultiply), isDistinct = false) val filterExpr = EqualTo(sub, sqrtPowMod) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - 
assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test boolean operators: = != < <= > >=") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = age = 30, b = age != 70, c = 30 < age, d = 30 <= age, e = 30 > age, f = 30 >= age | fields age, a, b, c, d, e, f", + false), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias(EqualTo(UnresolvedAttribute("age"), Literal(30)), "a")(), + Alias(Not(EqualTo(UnresolvedAttribute("age"), Literal(70))), "b")(), + Alias(LessThan(Literal(30), UnresolvedAttribute("age")), "c")(), + Alias(LessThanOrEqual(Literal(30), UnresolvedAttribute("age")), "d")(), + Alias(GreaterThan(Literal(30), UnresolvedAttribute("age")), "e")(), + Alias(GreaterThanOrEqual(Literal(30), UnresolvedAttribute("age")), "f")()), + table) + val projectList = Seq( + UnresolvedAttribute("age"), + UnresolvedAttribute("a"), + UnresolvedAttribute("b"), + UnresolvedAttribute("c"), + UnresolvedAttribute("d"), + UnresolvedAttribute("e"), + UnresolvedAttribute("f")) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala index 36a31862b..0d3c12b82 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import 
org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} @@ -14,11 +13,13 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanStringFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -45,7 +46,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test concat with field") { @@ -63,7 +64,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test length") { @@ -77,7 +78,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test lower") { @@ -91,7 +92,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) 
+ comparePlans(expectedPlan, logPlan, false) } test("test upper - case insensitive") { @@ -105,7 +106,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test trim") { @@ -119,7 +120,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ltrim") { @@ -133,7 +134,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test rtrim") { @@ -147,7 +148,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test substring") { @@ -162,7 +163,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test like") { @@ -181,7 +182,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test 
position") { @@ -200,7 +201,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test replace") { @@ -219,6 +220,6 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala index 7cfcc33d5..cd857fc08 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq @@ -13,11 +12,13 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class 
PPLLogicalPlanTimeFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -36,7 +37,7 @@ class PPLLogicalPlanTimeFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test unix_timestamp") { @@ -51,6 +52,174 @@ class PPLLogicalPlanTimeFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test builtin time functions with name mapping") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = t + | | eval a = DAY_OF_WEEK(DATE('2020-08-26')) + | | eval b = DAY_OF_MONTH(DATE('2020-08-26')) + | | eval c = DAY_OF_YEAR(DATE('2020-08-26')) + | | eval d = WEEK_OF_YEAR(DATE('2020-08-26')) + | | eval e = WEEK(DATE('2020-08-26')) + | | eval f = MONTH_OF_YEAR(DATE('2020-08-26')) + | | eval g = HOUR_OF_DAY(DATE('2020-08-26')) + | | eval h = MINUTE_OF_HOUR(DATE('2020-08-26')) + | | eval i = SECOND_OF_MINUTE(DATE('2020-08-26')) + | | eval j = SUBDATE(DATE('2020-08-26'), 1) + | | eval k = ADDDATE(DATE('2020-08-26'), 1) + | | eval l = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')) + | | eval m = LOCALTIME() + | """.stripMargin, + false), + context) + + val table = UnresolvedRelation(Seq("t")) + val projectA = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofweek", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "a")()), + table) + val projectB = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofmonth", + 
Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "b")()), + projectA) + val projectC = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "c")()), + projectB) + val projectD = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "weekofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "d")()), + projectC) + val projectE = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "weekofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "e")()), + projectD) + val projectF = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "month", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "f")()), + projectE) + val projectG = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "hour", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "g")()), + projectF) + val projectH = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "minute", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "h")()), + projectG) + val projectI = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "second", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "i")()), + projectH) + val projectJ = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "date_sub", + Seq( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + Literal(1)), + isDistinct = false), + "j")()), + projectI) + val 
projectK = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "date_add", + Seq( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + Literal(1)), + isDistinct = false), + "k")()), + projectJ) + val projectL = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "datediff", + Seq( + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-02 00:00:00")), + isDistinct = false), + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-01 23:59:59")), + isDistinct = false)), + isDistinct = false), + "l")()), + projectK) + val projectM = Project( + Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("localtimestamp", Seq.empty, isDistinct = false), "m")()), + projectL) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, projectM) + comparePlans(expectedPlan, logPlan, false) } } From f81c1b3f38a38be7776cc5445678891b49484658 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 2 Aug 2024 15:02:46 +0800 Subject: [PATCH 2/2] shorten the map declaration Signed-off-by: Lantao Jin --- .../ppl/utils/BuiltinFunctionTranslator.java | 77 ++++++++++++------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index 6f17b1247..53c6673a8 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -11,9 +11,28 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; -import java.util.Locale; import java.util.Map; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; +import 
static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MODULUS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.HOUR_OF_DAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -22,31 +41,32 @@ public interface BuiltinFunctionTranslator { /** * The name mapping between PPL builtin functions to Spark builtin functions. 
*/ - static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING = new ImmutableMap.Builder() - // arithmetic operators - .put(BuiltinFunctionName.ADD.name().toLowerCase(Locale.ROOT), "+") - .put(BuiltinFunctionName.SUBTRACT.name().toLowerCase(Locale.ROOT), "-") - .put(BuiltinFunctionName.MULTIPLY.name().toLowerCase(Locale.ROOT), "*") - .put(BuiltinFunctionName.DIVIDE.name().toLowerCase(Locale.ROOT), "/") - .put(BuiltinFunctionName.MODULUS.name().toLowerCase(Locale.ROOT), "%") - // time functions - .put(BuiltinFunctionName.DAY_OF_WEEK.name().toLowerCase(Locale.ROOT), "dayofweek") - .put(BuiltinFunctionName.DAY_OF_MONTH.name().toLowerCase(Locale.ROOT), "dayofmonth") - .put(BuiltinFunctionName.DAY_OF_YEAR.name().toLowerCase(Locale.ROOT), "dayofyear") - .put(BuiltinFunctionName.WEEK_OF_YEAR.name().toLowerCase(Locale.ROOT), "weekofyear") - .put(BuiltinFunctionName.WEEK.name().toLowerCase(Locale.ROOT), "weekofyear") - .put(BuiltinFunctionName.MONTH_OF_YEAR.name().toLowerCase(Locale.ROOT), "month") - .put(BuiltinFunctionName.HOUR_OF_DAY.name().toLowerCase(Locale.ROOT), "hour") - .put(BuiltinFunctionName.MINUTE_OF_HOUR.name().toLowerCase(Locale.ROOT), "minute") - .put(BuiltinFunctionName.SECOND_OF_MINUTE.name().toLowerCase(Locale.ROOT), "second") - .put(BuiltinFunctionName.SUBDATE.name().toLowerCase(Locale.ROOT), "date_sub") // only maps subdate(date, days) - .put(BuiltinFunctionName.ADDDATE.name().toLowerCase(Locale.ROOT), "date_add") // only maps adddate(date, days) - .put(BuiltinFunctionName.DATEDIFF.name().toLowerCase(Locale.ROOT), "datediff") - .put(BuiltinFunctionName.LOCALTIME.name().toLowerCase(Locale.ROOT), "localtimestamp") - // condition functions - .put(BuiltinFunctionName.IS_NULL.name().toLowerCase(Locale.ROOT), "isnull") - .put(BuiltinFunctionName.IS_NOT_NULL.name().toLowerCase(Locale.ROOT), "isnotnull") - .build(); + static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING + = new ImmutableMap.Builder() + // arithmetic operators + .put(ADD, "+") + .put(SUBTRACT, "-") 
+ .put(MULTIPLY, "*") + .put(DIVIDE, "/") + .put(MODULUS, "%") + // time functions + .put(DAY_OF_WEEK, "dayofweek") + .put(DAY_OF_MONTH, "dayofmonth") + .put(DAY_OF_YEAR, "dayofyear") + .put(WEEK_OF_YEAR, "weekofyear") + .put(WEEK, "weekofyear") + .put(MONTH_OF_YEAR, "month") + .put(HOUR_OF_DAY, "hour") + .put(MINUTE_OF_HOUR, "minute") + .put(SECOND_OF_MINUTE, "second") + .put(SUBDATE, "date_sub") // only maps subdate(date, days) + .put(ADDDATE, "date_add") // only maps adddate(date, days) + .put(DATEDIFF, "datediff") + .put(LOCALTIME, "localtimestamp") + // condition functions + .put(IS_NULL, "isnull") + .put(IS_NOT_NULL, "isnotnull") + .build(); static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { @@ -54,8 +74,9 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { - String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); - name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.getOrDefault(name, name); + BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); + String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING + .getOrDefault(builtin, builtin.getName().getFunctionName()); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } }