diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 3f843dbe4..8138317fc 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -548,4 +548,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | ) |)""".stripMargin) } + + protected def createTableHttpLog(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + |( + | id INT, + | status_code INT, + | request_path STRING, + | timestamp STRING + |) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, 200, '/home', '2023-10-01 10:00:00'), + | (2, null, '/about', '2023-10-01 10:05:00'), + | (3, 500, '/contact', '2023-10-01 10:10:00'), + | (4, 301, '/home', '2023-10-01 10:15:00'), + | (5, 200, '/services', '2023-10-01 10:20:00'), + | (6, 403, '/home', '2023-10-01 10:25:00') + | """.stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index ea77ff990..d0a7d6b02 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, 
And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest @@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" + private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log" override def beforeAll(): Unit = { super.beforeAll() // Create test table createPartitionedStateCountryTable(testTable) + createTableHttpLog(testTableHttpLog) } protected override def afterEach(): Unit = { @@ -504,7 +506,63 @@ class FlintSparkPPLEvalITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("eval case function") { + val frame = sql(s""" + | source = $testTableHttpLog | + | eval status_category = + | case(status_code >= 200 AND status_code < 300, 'Success', + | status_code >= 300 AND status_code < 400, 'Redirection', + | status_code >= 400 AND status_code < 500, 'Client Error', + | status_code >= 500, 'Server Error' + | else 'Unknown' + | ) + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"), + Row(2, null, "/about", "2023-10-01 10:05:00", "Unknown"), + Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"), + Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"), + Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"), + Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getInt(0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val expectedColumns = + Array[String]("id", "status_code", "request_path", "timestamp", "status_category") + assert(frame.columns.sameElements(expectedColumns)) + + val logicalPlan: LogicalPlan = 
frame.queryExecution.logical + + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log")) + val conditionValueSequence = Seq( + (greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")), + (greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")), + (greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")), + ( + EqualTo( + Literal(true), + GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))), + Literal("Server Error"))) + val elseValue = Literal("Unknown") + val caseFunction = CaseWhen(conditionValueSequence, elseValue) + val aliasStatusCategory = Alias(caseFunction, "status_category")() + val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory) + val evalProject = Project(evalProjectList, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def greaterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = { + val and = And( + GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)), + LessThan(UnresolvedAttribute(fieldName), Literal(max))) + EqualTo(Literal(true), and) + } + // Todo excluded fields not supported yet + ignore("test single eval expression with excluded fields") { val frame = sql(s""" | source = $testTable | eval new_field = "New Field" | fields - age diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 9a21bb45a..310e09c3d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -348,4 +348,38 @@ class FlintSparkPPLFiltersITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test("case function used as filter") { + val frame = sql(s""" + | source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America' + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + val sorted = results.sorted + assert(sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val conditionValueSequence = Seq( + ( + EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + Literal("The United States of America"))) + val elseValue = Literal("Other country") + val caseFunction = CaseWhen(conditionValueSequence, elseValue) + val filterExpr = EqualTo(caseFunction, Literal("The United States of America")) + val filterPlan = Filter(filterExpr, table) + 
val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 0c83abd97..3953466b3 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -245,6 +245,7 @@ See the next samples of PPL queries : - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` + - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` **Filters With Logical Conditions** - `source = table | where c = 'test' AND a = 1 | fields a,b,c` @@ -265,6 +266,15 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval f = ispresent(a)` - `source = table | eval r = coalesce(a, b, c) | fields r` - `source = table | eval e = isempty(a) | fields e` + ``` + source = table | eval status_category = + case(a >= 200 AND a < 300, 'Success', + a >= 300 AND a < 400, 'Redirection', + a >= 400 AND a < 500, 'Client Error', + a >= 500, 'Server Error' + else 'Unknown' + ) + ``` Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous" - `source = table | eval a = 10 | fields a,b,c` diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 0fcb069c5..f84928178 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -98,6 +98,7 @@ ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; +ELSE: 'ELSE'; IN: 'IN'; // LOGICAL KEYWORDS diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 
b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index e7d7321a2..306a4c6e9 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -312,6 +312,7 @@ valueExpression | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic | primaryExpression # valueExpressionDefault | positionFunction # positionFunctionCall + | caseFunction # caseExpr | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr ; @@ -333,6 +334,10 @@ booleanExpression : ISEMPTY LT_PRTHS functionArg RT_PRTHS ; + caseFunction + : CASE LT_PRTHS logicalExpression COMMA valueExpression (COMMA logicalExpression COMMA valueExpression)* (ELSE valueExpression)? RT_PRTHS + ; + relevanceExpression : singleFieldRelevanceFunction | multiFieldRelevanceFunction diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 601ebd540..e7aecf495 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -171,6 +171,8 @@ public T visitIsEmpty(IsEmpty node, C context) { return visitChildren(node, context); } + // TODO add case + public T visitWindowFunction(WindowFunction node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 2aa99cd67..36fbdac4d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -611,6 +611,8 @@ public Expression visitKmeans(Kmeans node, CatalystPlanContext 
context) { @Override public Expression visitCase(Case node, CatalystPlanContext context) { + Stack initialNameExpressions = new Stack<>(); + initialNameExpressions.addAll(context.getNamedParseExpressions()); analyze(node.getElseClause(), context); Expression elseValue = context.getNamedParseExpressions().pop(); List> whens = new ArrayList<>(); @@ -633,6 +635,7 @@ public Expression visitCase(Case node, CatalystPlanContext context) { } context.retainAllNamedParseExpressions(e -> e); } + context.setNamedParseExpressions(initialNameExpressions); return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue))); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index cf641bde8..17778fb6c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -44,6 +44,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; @@ -199,6 +200,25 @@ public UnresolvedExpression visitBooleanFunctionCall(OpenSearchPPLParser.Boolean ctx.functionArgs().functionArg()); } + @Override + public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ctx) { + List whens = IntStream.range(0, ctx.caseFunction().logicalExpression().size()) + .mapToObj(index -> { + OpenSearchPPLParser.LogicalExpressionContext logicalExpressionContext = ctx.caseFunction().logicalExpression(index); + OpenSearchPPLParser.ValueExpressionContext valueExpressionContext = ctx.caseFunction().valueExpression(index); + UnresolvedExpression condition = visit(logicalExpressionContext); + UnresolvedExpression result = 
visit(valueExpressionContext); + return new When(condition, result); + }) + .collect(Collectors.toList()); + UnresolvedExpression elseValue = new Literal(null, DataType.NULL); + if(ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { + // else value is present + elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1)); + } + return new Case(new Literal(true, DataType.BOOLEAN), whens, elseValue); + } + @Override public UnresolvedExpression visitIsEmptyExpression(OpenSearchPPLParser.IsEmptyExpressionContext ctx) { Function trimFunction = new Function(TRIM.getName().getFunctionName(), Collections.singletonList(this.visitFunctionArg(ctx.functionArg())));