Skip to content

Commit

Permalink
Add PPL lambda expression syntax and json_array_all_match / json_array_any_match / json_array_filter functions
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <[email protected]>
  • Loading branch information
qianheng-aws committed Nov 1, 2024
1 parent 950009b commit 3978b19
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 10 deletions.
7 changes: 4 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ RT_SQR_PRTHS: ']';
SINGLE_QUOTE: '\'';
DOUBLE_QUOTE: '"';
BACKTICK: '`';
ARROW: '->';

// Operators. Bit

Expand Down Expand Up @@ -380,9 +381,9 @@ JSON_VALID: 'JSON_VALID';
//JSON_DELETE: 'JSON_DELETE';
//JSON_EXTEND: 'JSON_EXTEND';
//JSON_SET: 'JSON_SET';
//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_FILTER';
JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH';
JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER';
//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP';
//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE';

Expand Down
8 changes: 5 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ valueExpression
| timestampFunction # timestampFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
| LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr
| ident ARROW expression #lambda
| LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression #lambda
;

primaryExpression
Expand Down Expand Up @@ -851,9 +853,9 @@ jsonFunctionName
// | JSON_DELETE
// | JSON_EXTEND
// | JSON_SET
// | JSON_ARRAY_ALL_MATCH
// | JSON_ARRAY_ANY_MATCH
// | JSON_ARRAY_FILTER
| JSON_ARRAY_ALL_MATCH
| JSON_ARRAY_ANY_MATCH
| JSON_ARRAY_FILTER
// | JSON_ARRAY_MAP
// | JSON_ARRAY_REDUCE
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.tree.FieldSummary;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -179,6 +180,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a PPL lambda expression node (e.g. {@code x -> x > 0}).
 * Default implementation simply delegates to {@code visitChildren}.
 */
public T visitPPLLambdaFunction(PPLLambdaFunction node, C context) {
  return visitChildren(node, context);
}

public T visitIsEmpty(IsEmpty node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
 * Expression node of a PPL lambda expression, e.g. {@code x -> x > 0} or
 * {@code (x, y) -> x + y}. Holds the lambda body expression ({@code function})
 * and the lambda parameter names ({@code funcArgs}).
 */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class PPLLambdaFunction extends UnresolvedExpression {
  // The lambda body expression (right-hand side of "->").
  private final UnresolvedExpression function;
  // The lambda parameter names (left-hand side of "->"), in declaration order.
  private final List<QualifiedName> funcArgs;

  /** Children are the body expression first, followed by the parameters. */
  @Override
  public List<UnresolvedExpression> getChild() {
    List<UnresolvedExpression> children = new ArrayList<>();
    children.add(function);
    children.addAll(funcArgs);
    return children;
  }

  @Override
  public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
    return nodeVisitor.visitPPLLambdaFunction(this, context);
  }

  /** Renders the lambda in source-like form, e.g. {@code (x, y) -> x + y}. */
  @Override
  public String toString() {
    return String.format(
        "(%s) -> %s",
        funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
        function.toString()
    );
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ public enum BuiltinFunctionName {
// JSON_APPEND(FunctionName.of("json_append")),
// JSON_EXTEND(FunctionName.of("json_extend")),
// JSON_SET(FunctionName.of("json_set")),
// JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")),
// JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")),
// JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")),
JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")),
JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")),
JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")),
// JSON_ARRAY_MAP(FunctionName.of("json_array_map")),
// JSON_ARRAY_REDUCE(FunctionName.of("json_array_reduce")),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
Expand All @@ -14,13 +15,16 @@
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.DataTypes;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand All @@ -40,6 +44,7 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -59,7 +64,9 @@
import org.opensearch.sql.ppl.utils.AggregatorTransformer;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.JavaToScalaTransformer;
import scala.Option;
import scala.PartialFunction;
import scala.Tuple2;
import scala.collection.Seq;

Expand Down Expand Up @@ -423,6 +430,25 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys
return context.getNamedParseExpressions().push(udf);
}

/**
 * Translates a PPL lambda expression into a Catalyst {@code LambdaFunction}.
 * The translated lambda is both pushed onto the context's expression stack and
 * returned.
 */
@Override
public Expression visitPPLLambdaFunction(PPLLambdaFunction node, CatalystPlanContext context) {
    // Rewrite every UnresolvedAttribute inside the translated body into an
    // UnresolvedNamedLambdaVariable so Spark's analyzer binds it to a lambda
    // parameter rather than to a column of the underlying relation.
    // NOTE(review): this rewrites ALL attributes in the body, not only those
    // matching declared lambda parameters — a body referencing an outer column
    // (e.g. x -> x > threshold) would capture it as a lambda variable; confirm
    // this is intended.
    PartialFunction<Expression, Expression> transformer = JavaToScalaTransformer.toPartialFunction(
        expr -> expr instanceof UnresolvedAttribute,
        expr -> {
            UnresolvedAttribute attr = (UnresolvedAttribute) expr;
            return new UnresolvedNamedLambdaVariable(attr.nameParts());
        }
    );
    // Visiting the body pushes its (pre-substitution) translation onto the
    // expression stack; transformUp builds the substituted copy we keep.
    Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer);
    // Discard the pre-substitution body pushed by the visit above so it does
    // not leak into the surrounding expression being assembled.
    context.popNamedParseExpressions();
    // Lambda parameters, in declaration order.
    List<NamedExpression> argsResult = node.getFuncArgs().stream()
        .map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts())))
        .collect(Collectors.toList());
    // "false" = non-hidden lambda function (Catalyst's default).
    LambdaFunction lambdaFunction = new LambdaFunction(functionResult, seq(argsResult), false);
    context.getNamedParseExpressions().push(lambdaFunction);
    return lambdaFunction;
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.IsEmpty;
import org.opensearch.sql.ast.expression.PPLLambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand Down Expand Up @@ -427,6 +428,15 @@ public UnresolvedExpression visitTimestampFunctionCall(
ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx));
}

/**
 * Builds the AST node for a lambda expression, e.g. {@code x -> x > 0} or
 * {@code (x, y) -> x + y}: each identifier on the left of the arrow becomes a
 * qualified-name parameter, and the right-hand side becomes the body.
 */
@Override
public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) {
    List<QualifiedName> arguments =
        ctx.ident().stream()
            .map(Collections::singletonList)
            .map(this::visitIdentifiers)
            .collect(Collectors.toList());
    return new PPLLambdaFunction(visitExpression(ctx.expression()), arguments);
}

private List<UnresolvedExpression> timestampFunctionArguments(
OpenSearchPPLParser.TimestampFunctionCallContext ctx) {
List<UnresolvedExpression> args =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$;
import org.apache.spark.sql.catalyst.expressions.ArrayExists$;
import org.apache.spark.sql.catalyst.expressions.ArrayFilter$;
import org.apache.spark.sql.catalyst.expressions.ArrayForAll$;
import org.apache.spark.sql.catalyst.expressions.CurrentTimeZone$;
import org.apache.spark.sql.catalyst.expressions.CurrentTimestamp$;
import org.apache.spark.sql.catalyst.expressions.DateAddInterval$;
Expand Down Expand Up @@ -35,6 +38,9 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_ALL_MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_ANY_MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_FILTER;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY_LENGTH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_EXTRACT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_KEYS;
Expand Down Expand Up @@ -179,6 +185,18 @@ public interface BuiltinFunctionTransformer {
args -> {
return ToUTCTimestamp$.MODULE$.apply(CurrentTimestamp$.MODULE$.apply(), CurrentTimeZone$.MODULE$.apply());
})
.put(
JSON_ARRAY_ALL_MATCH,
args -> ArrayForAll$.MODULE$.apply(args.get(0), args.get(1))
)
.put(
JSON_ARRAY_ANY_MATCH,
args -> ArrayExists$.MODULE$.apply(args.get(0), args.get(1))
)
.put(
JSON_ARRAY_FILTER,
args -> ArrayFilter$.MODULE$.apply(args.get(0), args.get(1))
)
.build();

static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ package org.opensearch.flint.spark.ppl

import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayExists, ArrayFilter, ArrayForAll, EqualTo, GreaterThan, LambdaFunction, Literal, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}

Expand Down Expand Up @@ -230,4 +231,70 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite
val expectedPlan = Project(projectList, filterPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_all_match()") {
  val context = new CatalystPlanContext
  // Translate the PPL query into a Catalyst logical plan.
  val logPlan = planTransformer.visit(
    plan(
      pplParser,
      """source=t | eval a = json_array(1, 2, 3), b = json_array_all_match(a, x -> x > 0)""".stripMargin),
    context)

  // Expected: json_array(...) maps to Spark's array(...), and
  // json_array_all_match maps to ArrayForAll with a lambda predicate.
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val lambda = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val arrayExpr =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayExpr, "a")(),
      Alias(ArrayForAll(UnresolvedAttribute("a"), lambda), "b")()),
    UnresolvedRelation(Seq("t")))
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_any_match()") {
  val context = new CatalystPlanContext
  // Translate the PPL query into a Catalyst logical plan.
  val logPlan = planTransformer.visit(
    plan(
      pplParser,
      """source=t | eval a = json_array(1, 2, 3), b = json_array_any_match(a, x -> x > 0)""".stripMargin),
    context)

  // Expected: json_array(...) maps to Spark's array(...), and
  // json_array_any_match maps to ArrayExists with a lambda predicate.
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val lambda = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val arrayExpr =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayExpr, "a")(),
      Alias(ArrayExists(UnresolvedAttribute("a"), lambda), "b")()),
    UnresolvedRelation(Seq("t")))
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}

test("test json_array_filter()") {
  val context = new CatalystPlanContext
  // Translate the PPL query into a Catalyst logical plan.
  val logPlan = planTransformer.visit(
    plan(
      pplParser,
      """source=t | eval a = json_array(1, 2, 3), b = json_array_filter(a, x -> x > 0)""".stripMargin),
    context)

  // Expected: json_array(...) maps to Spark's array(...), and
  // json_array_filter maps to ArrayFilter with a lambda predicate.
  val lambdaVar = UnresolvedNamedLambdaVariable(seq("x"))
  val lambda = LambdaFunction(GreaterThan(lambdaVar, Literal(0)), Seq(lambdaVar))
  val arrayExpr =
    UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false)
  val evalProject = Project(
    Seq(
      UnresolvedStar(None),
      Alias(arrayExpr, "a")(),
      Alias(ArrayFilter(UnresolvedAttribute("a"), lambda), "b")()),
    UnresolvedRelation(Seq("t")))
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, logPlan, false)
}
}

0 comments on commit 3978b19

Please sign in to comment.