diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 8e796c6fb..9ca6cf258 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -55,6 +55,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where isempty(a)` - `source = table | where isblank(a)` - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` +- `source = table | where a not in (1, 2, 3) | fields a,b,c` ```sql source = table | eval status_category = diff --git a/docs/ppl-lang/functions/ppl-expressions.md b/docs/ppl-lang/functions/ppl-expressions.md index 6315663c2..171f97385 100644 --- a/docs/ppl-lang/functions/ppl-expressions.md +++ b/docs/ppl-lang/functions/ppl-expressions.md @@ -127,7 +127,7 @@ OR operator : NOT operator : - os> source=accounts | where not age in (32, 33) | fields age ; + os> source=accounts | where age not in (32, 33) | fields age ; fetched rows / total rows = 2/2 +-------+ | age | diff --git a/docs/ppl-lang/ppl-eval-command.md b/docs/ppl-lang/ppl-eval-command.md index cd0898c1b..1908c087c 100644 --- a/docs/ppl-lang/ppl-eval-command.md +++ b/docs/ppl-lang/ppl-eval-command.md @@ -80,6 +80,8 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` +- `source = table | eval f = a in ('foo', 'bar') | fields f` +- `source = table | eval f = a not in ('foo', 'bar') | fields f` Eval with `case` example: ```sql diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index e10b2e2a6..c3dd1d533 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, In, LessThan, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest @@ -688,4 +688,20 @@ class FlintSparkPPLEvalITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test IN expr in eval") { + val frame = sql(s""" + | source = $testTable | eval in = state in ('California', 'New York') | fields in + | """.stripMargin) + assertSameRows(Seq(Row(true), Row(true), Row(false), Row(false)), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val in = Alias( + In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York"))), + "in")() + val eval = Project(Seq(UnresolvedStar(None), in), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("in")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 14ef7ccc4..f2d7ee844 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, In, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -453,4 +453,18 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test NOT IN expr in filter") { + val frame = sql(s""" + | source = $testTable | where state not in ('California', 'New York') | fields state + | """.stripMargin) + assertSameRows(Seq(Row("Ontario"), Row("Quebec")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val in = In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York"))) + val filter = Filter(Not(in), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("state")), filter) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 2d9357ec5..1b8acac1b 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -385,7 +385,7 @@ logicalExpression comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr - | valueExpression IN valueList # inExpr + | valueExpression NOT? IN valueList # inExpr ; valueExpressionList @@ -1028,6 +1028,7 @@ keywordsCanBeId | ML | EXPLAIN // commands assist keywords + | IN | SOURCE | INDEX | DESC diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 32ed2c92f..5f5db5a5a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -15,6 +15,7 @@ import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.InSubquery$; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; @@ -769,7 +770,13 @@ public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { @Override public Expression visitIn(In node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : In"); + node.getField().accept(this, context); + Expression value = context.popNamedParseExpressions().get(); + List list = node.getValueList().stream().map( expression -> { + expression.accept(this, context); + return context.popNamedParseExpressions().get(); + }).collect(Collectors.toList()); + return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list))); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index ea51ca7a1..fa03fc856 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -24,6 +24,7 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; @@ -418,6 +419,13 @@ public UnresolvedExpression visitExistsSubqueryExpr(OpenSearchPPLParser.ExistsSu return new ExistsSubquery(astBuilder.visitSubSearch(ctx.subSearch())); } + @Override + public UnresolvedExpression visitInExpr(OpenSearchPPLParser.InExprContext ctx) { + UnresolvedExpression expr = new In(visit(ctx.valueExpression()), + ctx.valueList().literalValue().stream().map(this::visit).collect(Collectors.toList())); + return ctx.NOT() != null ? new Not(expr) : expr; + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala index 3e2b3cc30..ba0d78670 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, In, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -200,4 +200,17 @@ class PPLLogicalPlanEvalTranslatorTestSuite val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + + test("test IN expr in eval") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval in = a in ('Hello', 'World') | fields in"), + context) + + val in = Alias(In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World"))), "in")() + val eval = Project(Seq(UnresolvedStar(None), in), UnresolvedRelation(Seq("t"))) + val expectedPlan = Project(Seq(UnresolvedAttribute("in")), eval) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala index 20809db95..fe9304e22 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -233,4 +233,15 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) comparePlans(expectedPlan, logPlan, false) } + + test("test IN expr in filter") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | where a in ('Hello', 'World')"), context) + + val in = In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World"))) + val filter = Filter(in, UnresolvedRelation(Seq("t"))) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } }