diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala new file mode 100644 index 000000000..127b29295 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -0,0 +1,554 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.DoubleType + +class FlintSparkPPLBuiltinFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test string functions - concat") { + val frame = sql(s""" + | source = $testTable name=concat('He', 'llo') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction("concat", seq(Literal("He"), Literal("llo")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - concat with field") { + val frame = sql(s""" + | source = $testTable name=concat('Hello', state) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array.empty + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction( + "concat", + seq(Literal("Hello"), UnresolvedAttribute("state")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - length") { + val frame = sql(s""" + | source = $testTable |where length(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test function name should be insensitive") { + val frame = sql(s""" + | source = $testTable |where leNgTh(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - lower") { + val frame = sql(s""" + | source = $testTable |where lower(name) = "hello" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("lower", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("hello")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - upper") { + val frame = sql(s""" + | source = $testTable |where upper(name) = upper("hello") | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + UnresolvedFunction("upper", seq(Literal("hello")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - substring") { + val frame = sql(s""" + | source = $testTable |where substring(name, 2, 2) = "el" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "substring", + seq(UnresolvedAttribute("name"), Literal(2), Literal(2)), + isDistinct = false), + Literal("el")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - like") { + val frame = sql(s""" + | source = $testTable | where like(name, '_ello%') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val likeFunction = UnresolvedFunction( + "like", + seq(UnresolvedAttribute("name"), Literal("_ello%")), + isDistinct = false) + + val filterPlan = Filter(likeFunction, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace") { + val frame = sql(s""" + | source = $testTable |where replace(name, 'o', ' ') = "Hell " | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false), + Literal("Hell ")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace and trim") { + val frame = sql(s""" + | source = $testTable |where trim(replace(name, 'o', ' ')) = "Hell" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "trim", + seq( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false)), + isDistinct = false), + Literal("Hell")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs") { + val frame = sql(s""" + | source = $testTable |where age = abs(-30) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("abs", seq(Literal(-30)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs with field") { + val frame = sql(s""" + | source = $testTable |where abs(age) = 30 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("abs", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(30)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ceil") { + val frame = sql(s""" + | source = $testTable |where age = ceil(29.7) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("ceil", seq(Literal(29.7d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - floor") { + val frame = sql(s""" + | source = $testTable |where age = floor(30.4) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("floor", seq(Literal(30.4d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ln") { + val frame = sql(s""" + | source = $testTable |where ln(age) > 4 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(4)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - mod") { + val frame = sql(s""" + | source = $testTable |where mod(age, 10) = 0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("mod", seq(UnresolvedAttribute("age"), Literal(10)), isDistinct = false), + Literal(0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - pow and sqrt") { + val frame = sql(s""" + | source = $testTable |where sqrt(pow(age, 2)) = 30.0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "sqrt", + seq( + UnresolvedFunction( + "pow", + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false)), + isDistinct = false), + Literal(30.0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test time functions - from_unixtime and unix_timestamp") { + val frame = sql(s""" + | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction( + "unix_timestamp", + seq(UnresolvedFunction("from_unixtime", seq(Literal(1700000001)), isDistinct = false)), + isDistinct = false), + Literal(1700000000)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 086413ca4..aac3c3f36 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -254,19 +254,25 @@ logicalExpression | left = logicalExpression OR right = logicalExpression # logicalOr | left = logicalExpression (AND)? right = logicalExpression # logicalAnd | left = logicalExpression XOR right = logicalExpression # logicalXor + | booleanExpression # booleanExpr ; comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr + | valueExpression IN valueList # inExpr ; valueExpression - : primaryExpression # valueExpressionDefault + : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic + | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic + | primaryExpression # valueExpressionDefault + | positionFunction # positionFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr ; primaryExpression - : fieldExpression + : evalFunctionCall + | fieldExpression | literalValue ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6d14db328..04f4320c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -54,6 +54,7 @@ import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ppl.utils.AggregatorTranslator; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -397,7 +398,21 @@ public Expression visitEval(Eval node, CatalystPlanContext context) { @Override public Expression visitFunction(Function node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Function"); + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6f1129b04..92e9dd458 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -112,6 +112,15 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); } + /** + * Value Expression. + */ + @Override + public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { + return new Function( + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + } + @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java new file mode 100644 index 000000000..0d57fea20 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface BuiltinFunctionTranslator { + + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { + if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { + // TODO change it when UDF is supported + // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions + throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); + } else { + String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 0c7269a07..4345b0897 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -6,10 +6,15 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.types.BooleanType$; import org.apache.spark.sql.types.ByteType$; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; @@ -33,8 +38,8 @@ * translate the PPL ast expressions data-types into catalyst data-types */ public interface DataTypeTransformer { - static Seq seq(T element) { - return seq(List.of(element)); + static Seq seq(T... elements) { + return seq(List.of(elements)); } static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); @@ -46,6 +51,16 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return DateType$.MODULE$; case INTEGER: return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case BOOLEAN: + return BooleanType$.MODULE$; + case SHORT: + return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; default: diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala index a36b34ef4..0c116a728 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala @@ -52,5 +52,4 @@ trait LogicalPlanTestUtils { // Return the string representation of the transformed plan transformedPlan.toString } - } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..24336b098 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanMathFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test abs") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = abs(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("abs", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ceil") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ceil(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ceil", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test floor") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = floor(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("floor", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ln") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ln(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ln", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test mod") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = mod(10, 4)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("mod", seq(Literal(10), Literal(4)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test pow") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = pow(2, 3)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("pow", seq(Literal(2), Literal(3)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test sqrt") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sqrt(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("sqrt", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test arithmetic: + - * / %") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - sqrt(pow(a, 2)) * 1 = sqrt(pow(a, 2)) % 1", + false), + context) + val table = UnresolvedRelation(Seq("t")) + // sqrt(pow(a, 2)) + val sqrtPow = + UnresolvedFunction( + "sqrt", + seq( + UnresolvedFunction( + "pow", + seq(UnresolvedAttribute("a"), Literal(2)), + isDistinct = false)), + isDistinct = false) + // sqrt(pow(a, 2)) / 1 + val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) * 1 + val sqrtPowMultiply = + UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) % 1 + val sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 + val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false) + val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false) + val filterExpr = EqualTo(sub, sqrtPowMod) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..36a31862b --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala @@ -0,0 +1,224 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanStringFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test unknown function") { + val context = new CatalystPlanContext + intercept[SyntaxCheckException] { + planTransformer.visit(plan(pplParser, "source=t a = unknown(b)", false), context) + } + } + + test("test concat") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a = CONCAT('hello', 'world')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("concat", seq(Literal("hello"), Literal("world")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test concat with field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = CONCAT('hello', b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "concat", + seq(Literal("hello"), UnresolvedAttribute("b")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test length") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LENGTH(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("length", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test lower") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LOWER(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("lower", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test upper - case insensitive") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = uPPer(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("upper", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test trim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = trim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("trim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ltrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ltrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ltrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test rtrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = rtrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("rtrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test substring") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = substring(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("substring", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test like") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a=like(b, 'Hatti_')", false), context) + + val table = UnresolvedRelation(Seq("t")) + val likeExpr = new Like(UnresolvedAttribute("a"), Literal("Hatti_")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "like", + seq(UnresolvedAttribute("b"), Literal("Hatti_")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test position") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=position('world' IN 'helloworld')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "position", + seq(Literal("world"), Literal("helloworld")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test replace") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=replace('hello', 'l', 'x')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "replace", + seq(Literal("hello"), Literal("l"), Literal("x")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..7cfcc33d5 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanTimeFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test from_unixtime") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = from_unixtime(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("from_unixtime", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test unix_timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = unix_timestamp(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("unix_timestamp", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +}