Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Lambda and add related array functions #864

Merged
merged 13 commits into from
Nov 5, 2024
7 changes: 4 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ RT_SQR_PRTHS: ']';
SINGLE_QUOTE: '\'';
DOUBLE_QUOTE: '"';
BACKTICK: '`';
ARROW: '->';

// Operators. Bit

Expand Down Expand Up @@ -380,9 +381,9 @@ JSON_VALID: 'JSON_VALID';
//JSON_DELETE: 'JSON_DELETE';
//JSON_EXTEND: 'JSON_EXTEND';
//JSON_SET: 'JSON_SET';
//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_FILTER';
JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH';
JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER';
//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP';
//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ valueExpression
| timestampFunction # timestampFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
| LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr
| ident ARROW expression #lambda
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
| LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression #lambda
;

primaryExpression
Expand Down Expand Up @@ -851,9 +853,9 @@ jsonFunctionName
// | JSON_DELETE
// | JSON_EXTEND
// | JSON_SET
// | JSON_ARRAY_ALL_MATCH
// | JSON_ARRAY_ANY_MATCH
// | JSON_ARRAY_FILTER
| JSON_ARRAY_ALL_MATCH
| JSON_ARRAY_ANY_MATCH
| JSON_ARRAY_FILTER
// | JSON_ARRAY_MAP
// | JSON_ARRAY_REDUCE
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.tree.FieldSummary;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -179,6 +180,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a PPL lambda function AST node ({@link PPLLambdaFunction}).
 *
 * <p>Default implementation simply delegates to {@code visitChildren}, i.e. it visits the
 * lambda body and each lambda parameter; subclasses override this to translate the lambda.
 *
 * @param node the lambda function node
 * @param context visitor context threaded through the traversal
 * @return the result of visiting the node's children
 */
public T visitPPLLambdaFunction(PPLLambdaFunction node, C context) {
return visitChildren(node, context);
}

public T visitIsEmpty(IsEmpty node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
 * AST node for a lambda expression, e.g. {@code (x, y) -> x > y}.
 *
 * <p>{@code function} is the lambda body expression and {@code funcArgs} is the ordered list of
 * lambda parameter names.
 */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class PPLLambdaFunction extends UnresolvedExpression {
  private final UnresolvedExpression function;
  private final List<QualifiedName> funcArgs;

  /** Children are the lambda body followed by each parameter name, in declaration order. */
  @Override
  public List<UnresolvedExpression> getChild() {
    List<UnresolvedExpression> kids = new ArrayList<>(funcArgs.size() + 1);
    kids.add(function);
    kids.addAll(funcArgs);
    return kids;
  }

  @Override
  public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
    return nodeVisitor.visitPPLLambdaFunction(this, context);
  }

  /** Renders the lambda in arrow form, e.g. {@code (x, y) -> body}. */
  @Override
  public String toString() {
    String params = funcArgs.stream().map(Object::toString).collect(Collectors.joining(", "));
    return String.format("(%s) -> %s", params, function.toString());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ public enum BuiltinFunctionName {
// JSON_APPEND(FunctionName.of("json_append")),
// JSON_EXTEND(FunctionName.of("json_extend")),
// JSON_SET(FunctionName.of("json_set")),
// JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")),
// JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")),
// JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")),
JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")),
JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")),
JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")),
// JSON_ARRAY_MAP(FunctionName.of("json_array_map")),
// JSON_ARRAY_REDUCE(FunctionName.of("json_array_reduce")),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
Expand All @@ -14,13 +15,16 @@
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.DataTypes;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand All @@ -40,6 +44,7 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -59,7 +64,9 @@
import org.opensearch.sql.ppl.utils.AggregatorTransformer;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.JavaToScalaTransformer;
import scala.Option;
import scala.PartialFunction;
import scala.Tuple2;
import scala.collection.Seq;

Expand Down Expand Up @@ -423,6 +430,25 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys
return context.getNamedParseExpressions().push(udf);
}

/**
 * Translates a PPL lambda AST node into a Catalyst {@code LambdaFunction} expression and pushes
 * it onto the context's named-parse-expression stack.
 *
 * @param node the PPL lambda (body expression plus parameter names)
 * @param context Catalyst plan context whose expression stack receives the result
 * @return the constructed Catalyst {@code LambdaFunction}
 */
@Override
public Expression visitPPLLambdaFunction(PPLLambdaFunction node, CatalystPlanContext context) {
  // Rewrites every UnresolvedAttribute inside the lambda body into an
  // UnresolvedNamedLambdaVariable so Catalyst binds those names to the lambda's
  // parameters instead of resolving them as table columns.
  PartialFunction<Expression, Expression> transformer = JavaToScalaTransformer.toPartialFunction(
    expr -> expr instanceof UnresolvedAttribute,
    expr -> {
      UnresolvedAttribute attr = (UnresolvedAttribute) expr;
      return new UnresolvedNamedLambdaVariable(attr.nameParts());
    }
  );
  // Visit the body (which pushes its result onto the context stack), then apply the
  // attribute->lambda-variable rewrite bottom-up over the visited expression tree.
  Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer);
  // NOTE(review): pops the un-transformed body that the accept() above pushed; the
  // transformed copy is embedded in the LambdaFunction pushed below instead.
  context.popNamedParseExpressions();
  // One lambda variable per declared parameter, preserving declaration order.
  List<NamedExpression> argsResult = node.getFuncArgs().stream()
    .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts())))
    .collect(Collectors.toList());
  // Third argument is Spark's `hidden` flag for LambdaFunction — false keeps the
  // parameters visible for resolution (TODO confirm against the Spark version in use).
  LambdaFunction lambdaFunction = new LambdaFunction(functionResult, seq(argsResult), false);
  context.getNamedParseExpressions().push(lambdaFunction);
  return lambdaFunction;
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.IsEmpty;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand Down Expand Up @@ -427,6 +428,15 @@ public UnresolvedExpression visitTimestampFunctionCall(
ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx));
}

/**
 * Builds the AST node for a parsed lambda expression.
 *
 * <p>Each parameter identifier is converted to a {@link QualifiedName}, and the lambda body is
 * visited as an ordinary expression.
 *
 * @param ctx the ANTLR lambda parse context
 * @return a {@link PPLLambdaFunction} wrapping the body and parameter names
 */
@Override
public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) {
  List<QualifiedName> params =
      ctx.ident().stream()
          .map(ident -> visitIdentifiers(Collections.singletonList(ident)))
          .collect(Collectors.toList());
  UnresolvedExpression body = visitExpression(ctx.expression());
  return new PPLLambdaFunction(body, params);
}

private List<UnresolvedExpression> timestampFunctionArguments(
OpenSearchPPLParser.TimestampFunctionCallContext ctx) {
List<UnresolvedExpression> args =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$;
import org.apache.spark.sql.catalyst.expressions.ArrayExists$;
import org.apache.spark.sql.catalyst.expressions.ArrayFilter$;
import org.apache.spark.sql.catalyst.expressions.ArrayForAll$;
import org.apache.spark.sql.catalyst.expressions.CurrentTimeZone$;
import org.apache.spark.sql.catalyst.expressions.CurrentTimestamp$;
import org.apache.spark.sql.catalyst.expressions.DateAddInterval$;
Expand Down Expand Up @@ -35,6 +38,9 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_ALL_MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_ANY_MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_FILTER;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_LENGTH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_EXTRACT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_KEYS;
Expand Down Expand Up @@ -179,6 +185,18 @@ public interface BuiltinFunctionTransformer {
args -> {
return ToUTCTimestamp$.MODULE$.apply(CurrentTimestamp$.MODULE$.apply(), CurrentTimeZone$.MODULE$.apply());
})
.put(
JSON_ARRAY_ALL_MATCH,
args -> ArrayForAll$.MODULE$.apply(args.get(0), args.get(1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we accept both ARRAY and JSON ARRAY STRING?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Em, I think we can keep this lambda functions and remove the JSON_ARRAY_ prefix since they seem common functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to use the original function name as spark

)
.put(
JSON_ARRAY_ANY_MATCH,
args -> ArrayExists$.MODULE$.apply(args.get(0), args.get(1))
)
.put(
JSON_ARRAY_FILTER,
args -> ArrayFilter$.MODULE$.apply(args.get(0), args.get(1))
)
.build();

static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ package org.opensearch.flint.spark.ppl

import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayExists, ArrayFilter, ArrayForAll, EqualTo, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}

Expand Down Expand Up @@ -230,4 +231,70 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite
val expectedPlan = Project(projectList, filterPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_all_match()") {
  // Translate the PPL query and build the expected Catalyst plan by hand:
  // json_array(...) -> array(...), json_array_all_match(a, x -> x > 0) -> forall.
  val context = new CatalystPlanContext
  val logPlan =
    planTransformer.visit(
      plan(
        pplParser,
        """source=t | eval a = json_array(1, 2, 3), b = json_array_all_match(a, x -> x > 0)""".stripMargin),
      context)

  val relation = UnresolvedRelation(Seq("t"))
  val arrayCall =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val predicate = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayCall, "a")(),
      Alias(ArrayForAll(UnresolvedAttribute("a"), predicate), "b")()),
    relation)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_any_match()") {
  // json_array_any_match translates to Spark's `exists` higher-order function.
  val context = new CatalystPlanContext
  val logPlan =
    planTransformer.visit(
      plan(
        pplParser,
        """source=t | eval a = json_array(1, 2, 3), b = json_array_any_match(a, x -> x > 0)""".stripMargin),
      context)

  val relation = UnresolvedRelation(Seq("t"))
  val arrayCall =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val predicate = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayCall, "a")(),
      Alias(ArrayExists(UnresolvedAttribute("a"), predicate), "b")()),
    relation)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_filter()") {
  // json_array_filter translates to Spark's `filter` higher-order function.
  val context = new CatalystPlanContext
  val logPlan =
    planTransformer.visit(
      plan(
        pplParser,
        """source=t | eval a = json_array(1, 2, 3), b = json_array_filter(a, x -> x > 0)""".stripMargin),
      context)

  val relation = UnresolvedRelation(Seq("t"))
  val arrayCall =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val predicate = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayCall, "a")(),
      Alias(ArrayFilter(UnresolvedAttribute("a"), predicate), "b")()),
    relation)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}
}
Loading