From f475c4394f96d4e6904dcdca08ab9ffda2333b48 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 14 Nov 2024 09:03:51 +0800 Subject: [PATCH] Support parenthesized expression in filter (#888) (cherry picked from commit a80aa04b9bbadfe68a3ed0ff59c3edce5251aac6) --- docs/ppl-lang/PPL-Example-Commands.md | 4 + docs/ppl-lang/ppl-where-command.md | 13 +- .../ppl/FlintSparkPPLFiltersITSuite.scala | 92 +++++++ .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 + .../sql/ppl/parser/AstExpressionBuilder.java | 5 + ...lPlanParenthesizedConditionTestSuite.scala | 244 ++++++++++++++++++ 6 files changed, 351 insertions(+), 8 deletions(-) create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 4ea564111..7a766ac61 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -50,6 +50,10 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md index c954623c3..aa7d9299e 100644 --- a/docs/ppl-lang/ppl-where-command.md +++ b/docs/ppl-lang/ppl-where-command.md @@ -27,15 +27,15 @@ PPL query: ### Additional Examples #### **Filters With Logical Conditions** -``` -- 
`source = table | where c = 'test' AND a = 1 | fields a,b,c` -- `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` -- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - `source = table | where a = 1 | fields a,b,c` - `source = table | where a >= 1 | fields a,b,c` - `source = table | where a < 1 | fields a,b,c` - `source = table | where b != 'test' | fields a,b,c` - `source = table | where c = 'test' | fields a,b,c | head 3` +- `source = table | where c = 'test' AND a = 1 | fields a,b,c` +- `source = table | where c != 'test' OR a > 1 | fields a,b,c` +- `source = table | where (b > 1 OR a > 1) AND c != 'test' | fields a,b,c` +- `source = table | where c = 'test' NOT a > 1 | fields a,b,c` - Note: "AND" is optional - `source = table | where ispresent(b)` - `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` - `source = table | where isempty(a)` @@ -45,7 +45,6 @@ PPL query: - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', a >= 300 AND a < 400, 'Redirection', @@ -57,10 +56,8 @@ PPL query: a >= 400 AND a < 500, 'Client Error', a >= 500, 'Server Error' else 'Incorrect HTTP status code' - ) = 'Incorrect HTTP status code' - + ) = 'Incorrect HTTP status code'` - `source = table | eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1) | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even' | stats count() by factor` -``` \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 
f2d7ee844..62c735597 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -467,4 +467,96 @@ class FlintSparkPPLFiltersITSuite val expectedPlan = Project(Seq(UnresolvedAttribute("state")), filter) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test parenthesis in filter") { /* Runs the same three predicates without and with parentheses: the first query is expected to return both Canada rows, the parenthesized one only the row with age 20, and its logical plan must be And(Or(..), LessThan(..)). */ + val frame = sql(s""" + | source = $testTable | where country = 'Canada' or age > 60 and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("John", 25, "Canada"), Row("Jane", 20, "Canada")), frame) + + val frameWithParenthesis = sql(s""" + | source = $testTable | where (country = 'Canada' or age > 60) and age < 25 | fields name, age, country + | """.stripMargin) + assertSameRows(Seq(Row("Jane", 20, "Canada")), frameWithParenthesis) + + val logicalPlan: LogicalPlan = frameWithParenthesis.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("Canada")), + GreaterThan(UnresolvedAttribute("age"), Literal(60))), + LessThan(UnresolvedAttribute("age"), Literal(25))), + table) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("country")), + filter) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex and nested parenthesis in filter") { /* Runs eight parenthesized WHERE clauses (nested, redundant, and combined with BETWEEN, IN, LIKE and NOT); only the returned rows are asserted here, not the logical plan. */ + val frame1 = sql(s""" + | source = $testTable | WHERE (age > 18 AND (state = 'California' OR state = 'New York')) + | """.stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame1) + + val frame2 = sql(s""" + | source = $testTable | WHERE ((((age > 18) AND ((((state = 'California') OR state = 'New York')))))) + |
""".stripMargin) + assertSameRows( + Seq( + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame2) + + val frame3 = sql(s""" + | source = $testTable | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 OR country = 'Canada') + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4)), + frame3) + + val frame4 = sql(s""" + | source = $testTable | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020))) + | """.stripMargin) + assertSameRows(Seq(Row("Jake", 70, "California", "USA", 2023, 4)), frame4) + + val frame5 = sql(s""" + | source = $testTable | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18)) + | """.stripMargin) + assertSameRows( + Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jake", 70, "California", "USA", 2023, 4)), + frame5) + + val frame6 = sql(s""" + | source = $testTable | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12))) + | """.stripMargin) + assertSameRows(Seq(Row("Hello", 30, "New York", "USA", 2023, 4)), frame6) + + val frame7 = sql(s""" + | source = $testTable | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8)) + | """.stripMargin) + assertSameRows( + Seq( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4)), + frame7) + + val frame8 = sql(s""" + | source = $testTable | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6)))) + | """.stripMargin) + assertSameRows(Seq(), frame8) /* no row is expected to match this condition */ + } } diff --git
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 1cfd172f7..1ef0212de 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -422,6 +422,7 @@ expression logicalExpression : NOT logicalExpression # logicalNot + | LT_PRTHS logicalExpression RT_PRTHS # parentheticLogicalExpr | comparisonExpression # comparsion | left = logicalExpression (AND)? right = logicalExpression # logicalAnd | left = logicalExpression OR right = logicalExpression # logicalOr diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4b7c8a1c1..8f8b2d27d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -157,6 +157,11 @@ public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArit ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } + @Override + public UnresolvedExpression visitParentheticLogicalExpr(OpenSearchPPLParser.ParentheticLogicalExprContext ctx) { + return visit(ctx.logicalExpression()); // Parentheses only group the condition: return the visited inner logical expression, no wrapper node is created + } + @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala new file mode 100644 index 000000000..a70415aab --- /dev/null +++
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParenthesizedConditionTestSuite.scala @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Not, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanParenthesizedConditionTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { /* Each test parses a PPL query whose WHERE clause uses parentheses, translates it with CatalystQueryPlanVisitor, and compares the result against a hand-built Project(Filter(..)) plan using comparePlans with checkAnalysis = false. */ + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test simple nested condition") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age > 18 AND (state = 'California' OR state = 'New York'))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested condition with duplicated parentheses") { /* Redundant parentheses: the expected plan is identical to the one built in the simple nested condition test above. */ + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((((age >
18) AND ((((state = 'California') OR state = 'New York'))))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + GreaterThan(UnresolvedAttribute("age"), Literal(18)), + Or( + EqualTo(UnresolvedAttribute("state"), Literal("California")), + EqualTo(UnresolvedAttribute("state"), Literal("New York")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test combining between function") { /* BETWEEN 1 AND 6 is expected as the pair month >= 1 AND month <= 6. */ + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (year = 2023 AND (month BETWEEN 1 AND 6)) AND (age >= 31 OR country = 'Canada')"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val betweenCondition = And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(1)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(6))) + val filter = Filter( + And( + And(EqualTo(UnresolvedAttribute("year"), Literal(2023)), betweenCondition), + Or( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(31)), + EqualTo(UnresolvedAttribute("country"), Literal("Canada")))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple levels of nesting") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE ((state = 'Texas' OR state = 'California') AND (age < 30 OR (country = 'USA' AND year > 2020)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo(UnresolvedAttribute("state"), Literal("California"))), + Or( + LessThan(UnresolvedAttribute("age"), Literal(30)), + And( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), +
GreaterThan(UnresolvedAttribute("year"), Literal(2020))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test with string functions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (LIKE(LOWER(name), 'a%') OR LIKE(LOWER(name), 'j%')) AND (LENGTH(state) > 6 OR (country = 'USA' AND age > 18))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Or( + UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("a%")), + isDistinct = false), + UnresolvedFunction( + "like", + Seq( + UnresolvedFunction("lower", Seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("j%")), + isDistinct = false)), + Or( + GreaterThan( + UnresolvedFunction("length", Seq(UnresolvedAttribute("state")), isDistinct = false), + Literal(6)), + And( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + GreaterThan(UnresolvedAttribute("age"), Literal(18))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex age ranges with nested conditions") { /* Expects BETWEEN as a >= / <= pair, IN as In(..), and != as Not(EqualTo(..)). */ + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (age BETWEEN 25 AND 40) AND ((state IN ('California', 'New York', 'Texas') AND year = 2023) OR (country != 'USA' AND (month = 1 OR month = 12)))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("age"), Literal(25)), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(40))), + Or( + And( + In( + UnresolvedAttribute("state"), + Seq(Literal("California"), Literal("New York"), Literal("Texas"))), +
EqualTo(UnresolvedAttribute("year"), Literal(2023))), + And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + Or( + EqualTo(UnresolvedAttribute("month"), Literal(1)), + EqualTo(UnresolvedAttribute("month"), Literal(12)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test nested NOT conditions") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE NOT (age < 18 OR (state = 'Alaska' AND year < 2020)) AND (country = 'USA' OR (country = 'Mexico' AND month BETWEEN 6 AND 8))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Not( + Or( + LessThan(UnresolvedAttribute("age"), Literal(18)), + And( + EqualTo(UnresolvedAttribute("state"), Literal("Alaska")), + LessThan(UnresolvedAttribute("year"), Literal(2020))))), + Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + And( + GreaterThanOrEqual(UnresolvedAttribute("month"), Literal(6)), + LessThanOrEqual(UnresolvedAttribute("month"), Literal(8)))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex boolean logic") { /* month % 2 = 0 is expected as EqualTo over an UnresolvedFunction named % and Literal(0). */ + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | WHERE (NOT (year < 2020 OR age < 18)) AND ((state = 'Texas' AND month % 2 = 0) OR (country = 'Mexico' AND (year = 2023 OR (year = 2022 AND month > 6))))"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = Filter( + And( + Not( + Or( + LessThan(UnresolvedAttribute("year"), Literal(2020)), + LessThan(UnresolvedAttribute("age"), Literal(18)))), + Or( + And( + EqualTo(UnresolvedAttribute("state"), Literal("Texas")), + EqualTo( +
UnresolvedFunction( + "%", + Seq(UnresolvedAttribute("month"), Literal(2)), + isDistinct = false), + Literal(0))), + And( + EqualTo(UnresolvedAttribute("country"), Literal("Mexico")), + Or( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + And( + EqualTo(UnresolvedAttribute("year"), Literal(2022)), + GreaterThan(UnresolvedAttribute("month"), Literal(6))))))), + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}