Skip to content

Commit

Permalink
Support IN expression in PPL (opensearch-project#823)
Browse files Browse the repository at this point in the history
* Support IN expression in PPL

Signed-off-by: Lantao Jin <[email protected]>

* add more examples in the eval doc

Signed-off-by: Lantao Jin <[email protected]>

---------

Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin authored Oct 29, 2024
1 parent 380fd50 commit ce50567
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where isempty(a)`
- `source = table | where isblank(a)`
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
- `source = table | where a not in (1, 2, 3) | fields a,b,c`

```sql
source = table | eval status_category =
Expand Down
2 changes: 1 addition & 1 deletion docs/ppl-lang/functions/ppl-expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ OR operator :
NOT operator :
os> source=accounts | where not age in (32, 33) | fields age ;
os> source=accounts | where age not in (32, 33) | fields age ;
fetched rows / total rows = 2/2
+-------+
| age |
Expand Down
2 changes: 2 additions & 0 deletions docs/ppl-lang/ppl-eval-command.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
- `source = table | eval f = a in ('foo', 'bar') | fields f`
- `source = table | eval f = a not in ('foo', 'bar') | fields f`

Eval with `case` example:
```sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, In, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -688,4 +688,20 @@ class FlintSparkPPLEvalITSuite
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}

test("test IN expr in eval") {
  // Run the PPL query end-to-end and check the evaluated boolean column.
  val frame = sql(s"""
       | source = $testTable | eval in = state in ('California', 'New York') | fields in
       | """.stripMargin)
  assertSameRows(Seq(Row(true), Row(true), Row(false), Row(false)), frame)

  // Expected unresolved plan: eval appends the aliased IN expression next to *,
  // then `fields in` projects only that column.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val aliasedIn = Alias(
    In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York"))),
    "in")()
  val evalProject = Project(Seq(UnresolvedStar(None), aliasedIn), relation)
  val expectedPlan = Project(Seq(UnresolvedAttribute("in")), evalProject)
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, In, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -453,4 +453,18 @@ class FlintSparkPPLFiltersITSuite
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test NOT IN expr in filter") {
  // Execute the query; only rows outside the excluded states should remain.
  val frame = sql(s"""
       | source = $testTable | where state not in ('California', 'New York') | fields state
       | """.stripMargin)
  assertSameRows(Seq(Row("Ontario"), Row("Quebec")), frame)

  // Expected plan shape: Project(state) over Filter(NOT (state IN (...))) over the table.
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val inExpr = In(UnresolvedAttribute("state"), Seq(Literal("California"), Literal("New York")))
  val expectedPlan =
    Project(Seq(UnresolvedAttribute("state")), Filter(Not(inExpr), relation))
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}
}
3 changes: 2 additions & 1 deletion ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ logicalExpression

comparisonExpression
: left = valueExpression comparisonOperator right = valueExpression # compareExpr
| valueExpression IN valueList # inExpr
| valueExpression NOT? IN valueList # inExpr
;

valueExpressionList
Expand Down Expand Up @@ -1028,6 +1028,7 @@ keywordsCanBeId
| ML
| EXPLAIN
// commands assist keywords
| IN
| SOURCE
| INDEX
| DESC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
Expand Down Expand Up @@ -765,7 +766,13 @@ public Expression visitDedupe(Dedupe node, CatalystPlanContext context) {

@Override
public Expression visitIn(In node, CatalystPlanContext context) {
    // NOTE(review): the pre-change `throw new IllegalStateException("Not Supported
    // operation : In")` line must be removed — left in place it makes the rest of
    // the method unreachable and the translation below is never executed.
    //
    // Translate the left-hand side of the IN expression first; the visitor pushes
    // its result onto the named-parse-expression stack, so pop it back off.
    node.getField().accept(this, context);
    Expression value = context.popNamedParseExpressions().get();
    // Translate each value in the IN list, in declaration order, via the same
    // push/pop protocol.
    List<Expression> list = node.getValueList().stream().map(expression -> {
        expression.accept(this, context);
        return context.popNamedParseExpressions().get();
    }).collect(Collectors.toList());
    // Build the Catalyst In expression (value IN (list...)) and push it so the
    // enclosing visitor can consume it.
    return context.getNamedParseExpressions().push(In$.MODULE$.apply(value, seq(list)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.expression.Interval;
Expand Down Expand Up @@ -418,6 +419,13 @@ public UnresolvedExpression visitExistsSubqueryExpr(OpenSearchPPLParser.ExistsSu
return new ExistsSubquery(astBuilder.visitSubSearch(ctx.subSearch()));
}

@Override
public UnresolvedExpression visitInExpr(OpenSearchPPLParser.InExprContext ctx) {
    // Left-hand side of the IN expression is visited first.
    UnresolvedExpression field = visit(ctx.valueExpression());
    // Then each literal alternative inside the parentheses, in order.
    List<UnresolvedExpression> values = ctx.valueList().literalValue().stream()
        .map(this::visit)
        .collect(Collectors.toList());
    UnresolvedExpression inExpr = new In(field, values);
    // An optional NOT keyword wraps the whole expression in a negation node.
    return ctx.NOT() == null ? inExpr : new Not(inExpr);
}

private QualifiedName visitIdentifiers(List<? extends ParserRuleContext> ctx) {
return new QualifiedName(
ctx.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, In, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}

Expand Down Expand Up @@ -200,4 +200,17 @@ class PPLLogicalPlanEvalTranslatorTestSuite
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t")))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test IN expr in eval") {
  val context = new CatalystPlanContext
  val logPlan =
    planTransformer.visit(
      plan(pplParser, "source=t | eval in = a in ('Hello', 'World') | fields in"),
      context)

  // eval adds the aliased IN expression alongside the star projection; the
  // trailing `fields in` then projects only the new column.
  val aliasedIn =
    Alias(In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World"))), "in")()
  val evalProject = Project(Seq(UnresolvedStar(None), aliasedIn), UnresolvedRelation(Seq("t")))
  val expectedPlan = Project(Seq(UnresolvedAttribute("in")), evalProject)
  comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

Expand Down Expand Up @@ -233,4 +233,15 @@ class PPLLogicalPlanFiltersTranslatorTestSuite
val expectedPlan = Project(Seq(UnresolvedStar(None)), filter)
comparePlans(expectedPlan, logPlan, false)
}

test("test IN expr in filter") {
  val context = new CatalystPlanContext
  val logPlan =
    planTransformer.visit(plan(pplParser, "source=t | where a in ('Hello', 'World')"), context)

  // `where` becomes a Filter over the IN predicate, beneath a star projection.
  val inExpr = In(UnresolvedAttribute("a"), Seq(Literal("Hello"), Literal("World")))
  val expectedPlan =
    Project(Seq(UnresolvedStar(None)), Filter(inExpr, UnresolvedRelation(Seq("t"))))
  comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}

0 comments on commit ce50567

Please sign in to comment.