From 9dc17ba1d85eaa77a245b4f6089402ad37c9c3bb Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 25 Jul 2024 00:10:09 +0800 Subject: [PATCH] IF function should support complex predicates in PPL (#2756) Signed-off-by: Lantao Jin --- docs/user/ppl/functions/condition.rst | 11 ++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 16 ++- .../opensearch/sql/ppl/parser/AstBuilder.java | 3 +- .../sql/ppl/parser/AstExpressionBuilder.java | 2 +- .../ppl/parser/AstExpressionBuilderTest.java | 114 ++++++++++++++++++ .../sql/sql/antlr/SQLSyntaxParserTest.java | 16 +++ 6 files changed, 154 insertions(+), 8 deletions(-) diff --git a/docs/user/ppl/functions/condition.rst b/docs/user/ppl/functions/condition.rst index fea76bedda..e48d4cb75c 100644 --- a/docs/user/ppl/functions/condition.rst +++ b/docs/user/ppl/functions/condition.rst @@ -181,3 +181,14 @@ Example:: | Bates | Nanette | Bates | | Adams | Dale | Adams | +----------+-------------+------------+ + + os> source=accounts | eval is_vip = if(age > 30 AND isnotnull(employer), true, false) | fields is_vip, firstname, lastname + fetched rows / total rows = 4/4 + +----------+-------------+------------+ + | is_vip | firstname | lastname | + |----------+-------------+------------| + | True | Amber | Duke | + | True | Hattie | Bond | + | False | Nanette | Bates | + | False | Dale | Adams | + +----------+-------------+------------+ diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 39fb7f53a6..4dc223b028 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -255,6 +255,7 @@ expression | valueExpression ; +// predicates logicalExpression : comparisonExpression # comparsion | NOT logicalExpression # logicalNot @@ -362,7 +363,7 @@ dataTypeFunctionCall // boolean functions booleanFunctionCall - : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS + : conditionFunctionName LT_PRTHS functionArgs RT_PRTHS ; convertedDataType @@ -382,7 +383,8 @@ evalFunctionName : mathematicalFunctionName | dateTimeFunctionName | textFunctionName - | conditionFunctionBase + | conditionFunctionName + | flowControlFunctionName | systemFunctionName | positionFunctionName ; @@ -392,7 +394,7 @@ functionArgs ; functionArg - : (ident EQUAL)? valueExpression + : (ident EQUAL)? expression ; relevanceArg @@ -623,11 +625,15 @@ timestampFunctionName ; // condition function return boolean value -conditionFunctionBase +conditionFunctionName : LIKE - | IF | ISNULL | ISNOTNULL + ; + +// flow control function return non-boolean value +flowControlFunctionName + : IF | IFNULL | NULLIF ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 3c693fa0bd..78fe28b49e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -332,8 +332,7 @@ public UnresolvedPlan visitTableFunction(TableFunctionContext ctx) { arg -> { String argName = (arg.ident() != null) ? arg.ident().getText() : null; builder.add( - new UnresolvedArgument( - argName, this.internalVisitExpression(arg.valueExpression()))); + new UnresolvedArgument(argName, this.internalVisitExpression(arg.expression()))); }); return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f36765d3d7..aec22ac231 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -187,7 +187,7 @@ public UnresolvedExpression visitTakeAggFunctionCall( /** Eval function. */ @Override public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext ctx) { - final String functionName = ctx.conditionFunctionBase().getText().toLowerCase(); + final String functionName = ctx.conditionFunctionName().getText().toLowerCase(); return buildFunction( FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), ctx.functionArgs().functionArg()); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index de230a1fee..fbb25549ab 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -138,6 +138,120 @@ public void testEvalFunctionExprNoArgs() { assertEqual("source=t | eval f=PI()", eval(relation("t"), let(field("f"), function("PI")))); } + @Test + public void testEvalIfFunctionExpr() { + assertEqual( + "source=t | eval f=if(true, 1, 0)", + eval( + relation("t"), + let(field("f"), function("if", booleanLiteral(true), intLiteral(1), intLiteral(0))))); + assertEqual( + "source=t | eval f=if(1>2, 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + compare(">", intLiteral(1), intLiteral(2)), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(1<=2, 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + compare("<=", intLiteral(1), intLiteral(2)), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(1=2, 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + compare("=", intLiteral(1), intLiteral(2)), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(1!=2, 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + compare("!=", intLiteral(1), intLiteral(2)), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(isnull(a), 1, 0)", + eval( + relation("t"), + let( + field("f"), + function("if", function("is null", field("a")), intLiteral(1), intLiteral(0))))); + assertEqual( + "source=t | eval f=if(isnotnull(a), 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", function("is not null", field("a")), intLiteral(1), intLiteral(0))))); + assertEqual( + "source=t | eval f=if(not 1>2, 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + not(compare(">", intLiteral(1), intLiteral(2))), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(not a in (0, 1), 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + not(in(field("a"), intLiteral(0), intLiteral(1))), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(not a in (0, 1) OR isnull(a), 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + or( + not(in(field("a"), intLiteral(0), intLiteral(1))), + function("is null", field("a"))), + intLiteral(1), + intLiteral(0))))); + assertEqual( + "source=t | eval f=if(like(a, '_a%b%c_d_'), 1, 0)", + eval( + relation("t"), + let( + field("f"), + function( + "if", + function("like", field("a"), stringLiteral("_a%b%c_d_")), + intLiteral(1), + intLiteral(0))))); + } + @Test public void testPositionFunctionExpr() { assertEqual( diff --git a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java index f68c27deea..c43044508b 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java @@ -719,6 +719,22 @@ public void canParseMultiMatchAlternateSyntax() { assertNotNull(parser.parse("SELECT * FROM test WHERE Field = multimatch(\"query\")")); } + @Test + public void canParseIfFunction() { + assertNotNull(parser.parse("SELECT IF(1 > 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 < 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 >= 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 <= 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 <> 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 != 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 = 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(true, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(1 IS NOT NULL, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(NOT 1 > 2, 1, 0)")); + assertNotNull(parser.parse("SELECT IF(NOT 1 IN (0, 1), 1, 0)")); + assertNotNull(parser.parse("SELECT IF(NOT 1 IN (0, 1) OR 1 IS NOT NULL, 1, 0)")); + } + private static Stream matchPhraseQueryComplexQueries() { return Stream.of( "SELECT * FROM t WHERE matchphrasequery(c, 3)",