From 4c48182f996beada9f5a814ed13de5dff4e02394 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 24 Jul 2024 13:35:29 +0800 Subject: [PATCH 1/3] Add string functions and math functions Signed-off-by: Lantao Jin --- .../src/main/antlr4/OpenSearchPPLParser.g4 | 5 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 17 +- .../ppl/utils/BuiltinFunctionTranslator.java | 29 +++ .../sql/ppl/utils/DataTypeTransformer.java | 4 +- .../spark/ppl/LogicalPlanTestUtils.scala | 4 +- ...PlanMathFunctionsTranslatorTestSuite.scala | 156 ++++++++++++ ...anStringFunctionsTranslatorTestSuite.scala | 224 ++++++++++++++++++ 7 files changed, 433 insertions(+), 6 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index f6cd0d4ee..83c735e17 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -263,17 +263,20 @@ logicalExpression comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr + | valueExpression IN valueList # inExpr ; valueExpression : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic | primaryExpression # valueExpressionDefault + | positionFunction # positionFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr ; primaryExpression - : fieldExpression + : evalFunctionCall + | fieldExpression | literalValue ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6d14db328..04f4320c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -54,6 +54,7 @@ import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ppl.utils.AggregatorTranslator; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -397,7 +398,21 @@ public Expression visitEval(Eval node, CatalystPlanContext context) { @Override public Expression visitFunction(Function node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Function"); + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java new file mode 100644 index 000000000..26a4288d7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface BuiltinFunctionTranslator { + + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { + if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { + // TODO should we support UDF in future? + throw new IllegalStateException("Unknown builtin function: " + function.getFuncName()); + } else { + String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 0c7269a07..848395f22 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -33,8 +33,8 @@ * translate the PPL ast expressions data-types into catalyst data-types */ public interface DataTypeTransformer { - static Seq seq(T element) { - return seq(List.of(element)); + static Seq seq(T... elements) { + return seq(List.of(elements)); } static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala index a36b34ef4..ec68f538b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala @@ -5,13 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} /** * general utility functions for ppl to spark transformation test */ -trait LogicalPlanTestUtils { +trait LogicalPlanTestUtils extends AnalysisTest { /** * utility function to compare two logical plans while ignoring the auto-generated expressionId @@ -52,5 +53,4 @@ trait LogicalPlanTestUtils { // Return the string representation of the transformed plan transformedPlan.toString } - } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..8aa0f61ac --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanMathFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test abs") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = abs(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("abs", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ceil") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ceil(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ceil", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test floor") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = floor(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("floor", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ln") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ln(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ln", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test mod") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = mod(10, 4)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("mod", seq(Literal(10), Literal(4)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test pow") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = pow(2, 3)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("pow", seq(Literal(2), Literal(3)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test sqrt") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sqrt(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("sqrt", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test arithmetic: + - * / %") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t a = b % 2 + 1 * 5 + 10 / 2", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "add", + seq( + UnresolvedFunction( + "add", + seq( + UnresolvedFunction( + "modulus", + seq(UnresolvedAttribute("b"), Literal(2)), + isDistinct = false), + UnresolvedFunction("multiply", seq(Literal(1), Literal(5)), isDistinct = false)), + isDistinct = false), + UnresolvedFunction("divide", seq(Literal(10), Literal(2)), isDistinct = false)), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..36a31862b --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala @@ -0,0 +1,224 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanStringFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test unknown function") { + val context = new CatalystPlanContext + intercept[SyntaxCheckException] { + planTransformer.visit(plan(pplParser, "source=t a = unknown(b)", false), context) + } + } + + test("test concat") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a = CONCAT('hello', 'world')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("concat", seq(Literal("hello"), Literal("world")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test concat with field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = CONCAT('hello', b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "concat", + seq(Literal("hello"), UnresolvedAttribute("b")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test length") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LENGTH(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("length", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test lower") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LOWER(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("lower", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test upper - case insensitive") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = uPPer(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("upper", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test trim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = trim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("trim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ltrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ltrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ltrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test rtrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = rtrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("rtrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test substring") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = substring(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("substring", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test like") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a=like(b, 'Hatti_')", false), context) + + val table = UnresolvedRelation(Seq("t")) + val likeExpr = new Like(UnresolvedAttribute("a"), Literal("Hatti_")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "like", + seq(UnresolvedAttribute("b"), Literal("Hatti_")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test position") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=position('world' IN 'helloworld')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "position", + seq(Literal("world"), Literal("helloworld")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test replace") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=replace('hello', 'l', 'x')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "replace", + seq(Literal("hello"), Literal("l"), Literal("x")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} From 2e96c9ef38c23e887608560cb0eac5443f5f49df Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 24 Jul 2024 16:23:51 +0800 Subject: [PATCH 2/3] add from_unixtime and unix_timestamp test Signed-off-by: Lantao Jin --- .../ppl/utils/BuiltinFunctionTranslator.java | 5 +- ...PlanTimeFunctionsTranslatorTestSuite.scala | 58 +++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index 26a4288d7..0d57fea20 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -19,8 +19,9 @@ public interface BuiltinFunctionTranslator { static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { - // TODO should we support UDF in future? - throw new IllegalStateException("Unknown builtin function: " + function.getFuncName()); + // TODO change it when UDF is supported + // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions + throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..812b2e24b --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanTimeFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers{ + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test from_unixtime") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a = from_unixtime(b)", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("from_unixtime", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test unix_timestamp") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a = unix_timestamp(b)", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("unix_timestamp", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} From 333aaee966299203b8ab437ff64f58ce3e1a109e Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 25 Jul 2024 22:23:42 +0800 Subject: [PATCH 3/3] Add IT Signed-off-by: Lantao Jin --- .../FlintSparkPPLBuiltinFunctionITSuite.scala | 554 ++++++++++++++++++ .../sql/ppl/utils/DataTypeTransformer.java | 15 + .../spark/ppl/LogicalPlanTestUtils.scala | 3 +- ...PlanMathFunctionsTranslatorTestSuite.scala | 37 +- ...PlanTimeFunctionsTranslatorTestSuite.scala | 12 +- 5 files changed, 597 insertions(+), 24 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala new file mode 100644 index 000000000..127b29295 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -0,0 +1,554 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.DoubleType + +class FlintSparkPPLBuiltinFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test string functions - concat") { + val frame = sql(s""" + | source = $testTable name=concat('He', 'llo') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction("concat", seq(Literal("He"), Literal("llo")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - concat with field") { + val frame = sql(s""" + | source = $testTable name=concat('Hello', state) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array.empty + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction( + "concat", + seq(Literal("Hello"), UnresolvedAttribute("state")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - length") { + val frame = sql(s""" + | source = $testTable |where length(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test function name should be insensitive") { + val frame = sql(s""" + | source = $testTable |where leNgTh(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - lower") { + val frame = sql(s""" + | source = $testTable |where lower(name) = "hello" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("lower", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("hello")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - upper") { + val frame = sql(s""" + | source = $testTable |where upper(name) = upper("hello") | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + UnresolvedFunction("upper", seq(Literal("hello")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - substring") { + val frame = sql(s""" + | source = $testTable |where substring(name, 2, 2) = "el" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "substring", + seq(UnresolvedAttribute("name"), Literal(2), Literal(2)), + isDistinct = false), + Literal("el")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - like") { + val frame = sql(s""" + | source = $testTable | where like(name, '_ello%') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val likeFunction = UnresolvedFunction( + "like", + seq(UnresolvedAttribute("name"), Literal("_ello%")), + isDistinct = false) + + val filterPlan = Filter(likeFunction, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace") { + val frame = sql(s""" + | source = $testTable |where replace(name, 'o', ' ') = "Hell " | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false), + Literal("Hell ")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace and trim") { + val frame = sql(s""" + | source = $testTable |where trim(replace(name, 'o', ' ')) = "Hell" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "trim", + seq( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false)), + isDistinct = false), + Literal("Hell")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs") { + val frame = sql(s""" + | source = $testTable |where age = abs(-30) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("abs", seq(Literal(-30)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs with field") { + val frame = sql(s""" + | source = $testTable |where abs(age) = 30 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("abs", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(30)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ceil") { + val frame = sql(s""" + | source = $testTable |where age = ceil(29.7) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("ceil", seq(Literal(29.7d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - floor") { + val frame = sql(s""" + | source = $testTable |where age = floor(30.4) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("floor", seq(Literal(30.4d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ln") { + val frame = sql(s""" + | source = $testTable |where ln(age) > 4 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(4)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - mod") { + val frame = sql(s""" + | source = $testTable |where mod(age, 10) = 0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("mod", seq(UnresolvedAttribute("age"), Literal(10)), isDistinct = false), + Literal(0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - pow and sqrt") { + val frame = sql(s""" + | source = $testTable |where sqrt(pow(age, 2)) = 30.0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "sqrt", + seq( + UnresolvedFunction( + "pow", + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false)), + isDistinct = false), + Literal(30.0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test time functions - from_unixtime and unix_timestamp") { + val frame = sql(s""" + | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction( + "unix_timestamp", + seq(UnresolvedFunction("from_unixtime", seq(Literal(1700000001)), isDistinct = false)), + isDistinct = false), + Literal(1700000000)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 848395f22..4345b0897 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -6,10 +6,15 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.types.BooleanType$; import org.apache.spark.sql.types.ByteType$; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; @@ -46,6 +51,16 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return DateType$.MODULE$; case INTEGER: return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case BOOLEAN: + return BooleanType$.MODULE$; + case SHORT: + return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; default: diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala index ec68f538b..0c116a728 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala @@ -5,14 +5,13 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} /** * general utility functions for ppl to spark transformation test */ -trait LogicalPlanTestUtils extends AnalysisTest { +trait LogicalPlanTestUtils { /** * utility function to compare two logical plans while ignoring the auto-generated expressionId diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala index 8aa0f61ac..24336b098 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -128,26 +128,33 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, "source=t a = b % 2 + 1 * 5 + 10 / 2", false), + plan( + pplParser, + "source=t | where sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - sqrt(pow(a, 2)) * 1 = sqrt(pow(a, 2)) % 1", + false), context) - val table = UnresolvedRelation(Seq("t")) - val filterExpr = EqualTo( - UnresolvedAttribute("a"), + // sqrt(pow(a, 2)) + val sqrtPow = UnresolvedFunction( - "add", + "sqrt", seq( UnresolvedFunction( - "add", - seq( - UnresolvedFunction( - "modulus", - seq(UnresolvedAttribute("b"), Literal(2)), - isDistinct = false), - UnresolvedFunction("multiply", seq(Literal(1), Literal(5)), isDistinct = false)), - isDistinct = false), - UnresolvedFunction("divide", seq(Literal(10), Literal(2)), isDistinct = false)), - isDistinct = false)) + "pow", + seq(UnresolvedAttribute("a"), Literal(2)), + isDistinct = false)), + isDistinct = false) + // sqrt(pow(a, 2)) / 1 + val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) * 1 + val sqrtPowMultiply = + UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) % 1 + val sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 + val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false) + val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false) + val filterExpr = EqualTo(sub, sqrtPowMod) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala index 812b2e24b..7cfcc33d5 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -19,16 +19,15 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanTimeFunctionsTranslatorTestSuite extends SparkFunSuite with LogicalPlanTestUtils - with Matchers{ + with Matchers { private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() test("test from_unixtime") { val context = new CatalystPlanContext - val logPlan = planTransformer.visit( - plan(pplParser, "source=t a = from_unixtime(b)", false), - context) + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = from_unixtime(b)", false), context) val table = UnresolvedRelation(Seq("t")) val filterExpr = EqualTo( @@ -42,9 +41,8 @@ class PPLLogicalPlanTimeFunctionsTranslatorTestSuite test("test unix_timestamp") { val context = new CatalystPlanContext - val logPlan = planTransformer.visit( - plan(pplParser, "source=t a = unix_timestamp(b)", false), - context) + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = unix_timestamp(b)", false), context) val table = UnresolvedRelation(Seq("t")) val filterExpr = EqualTo(