Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Lambda and add related array functions #864

Merged
merged 13 commits into from
Nov 5, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, Not}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.functions.{col, to_json}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLJsonFunctionITSuite
Expand Down Expand Up @@ -383,4 +384,73 @@ class FlintSparkPPLJsonFunctionITSuite
null))
assertSameRows(expectedSeq, frame)
}

test("test forall()") {
  // forall(array, lambda) is true only when EVERY element satisfies the predicate.
  // 0 and -1 violate x > 0, so the first query must return false.
  val frame = sql(s"""
     | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(false)), frame)

  // Every element is greater than -10, so forall returns true.
  val frame2 = sql(s"""
     | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(true)), frame2)

  // Struct elements: the second object has a = -1, so x.a > 0 does not hold for all.
  val frame3 = sql(s"""
     | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(false)), frame3)

  // BUGFIX: this case previously called exists(); the forall test must exercise forall.
  // Both objects have b = -1, so x.b < 0 holds for all elements and the result is true.
  val frame4 = sql(s"""
     | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = forall(array, x -> x.b < 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(true)), frame4)

}

test("test exists()") {
  // exists(array, lambda) is true when AT LEAST ONE element satisfies the predicate.
  // Several elements (1, 2, 1.1) are positive, so this query yields true.
  val anyPositive = sql(s"""
     | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(true)), anyPositive)

  // No element exceeds 10, so exists yields false.
  val noneAboveTen = sql(s"""
     | source = $testTable | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(false)), noneAboveTen)

  // Struct elements: the first object has a = 1, so some element satisfies x.a > 0.
  val structFieldMatch = sql(s"""
     | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.a > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(true)), structFieldMatch)

  // Both objects have b = -1, so no element satisfies x.b > 0.
  val structFieldNoMatch = sql(s"""
     | source = $testTable | eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = exists(array, x -> x.b > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(false)), structFieldNoMatch)

}

test("test filter()") {
  // filter(array, lambda) keeps only the elements that satisfy the predicate.
  val keepPositives = sql(s"""
     | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(Seq(1, 2, 1.1))), keepPositives)

  // No element exceeds 10, so filtering yields an empty array.
  val keepNone = sql(s"""
     | source = $testTable| eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row(Seq())), keepNone)

  // Struct elements: only the first object satisfies x.a > 0.
  // Compare via to_json since the rows carry struct values.
  val keepMatchingStructs = sql(s"""
     | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.a > 0) | head 1 | fields result
     | """.stripMargin)

  assertSameRows(Seq(Row("""[{"a":1,"b":-1}]""")), keepMatchingStructs.select(to_json(col("result"))))

  // Neither object satisfies x.b > 0, so the filtered struct array is empty.
  val keepNoStructs = sql(s"""
     | source = $testTable| eval array = json_array(json_object("a",1,"b",-1),json_object("a",-1,"b",-1)), result = filter(array, x -> x.b > 0) | head 1 | fields result
     | """.stripMargin)
  assertSameRows(Seq(Row("""[]""")), keepNoStructs.select(to_json(col("result"))))
}
}
12 changes: 9 additions & 3 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ RT_SQR_PRTHS: ']';
SINGLE_QUOTE: '\'';
DOUBLE_QUOTE: '"';
BACKTICK: '`';
ARROW: '->';

// Operators. Bit

Expand Down Expand Up @@ -384,15 +385,20 @@ JSON_VALID: 'JSON_VALID';
//JSON_DELETE: 'JSON_DELETE';
//JSON_EXTEND: 'JSON_EXTEND';
//JSON_SET: 'JSON_SET';
//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_FILTER';
//JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH';
//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
//JSON_ARRAY_FILTER: 'JSON_ARRAY_FILTER';
//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP';
//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE';

// COLLECTION FUNCTIONS
ARRAY: 'ARRAY';

// LAMBDA FUNCTIONS
//EXISTS: 'EXISTS';
FORALL: 'FORALL';
FILTER: 'FILTER';

// BOOL FUNCTIONS
LIKE: 'LIKE';
ISNULL: 'ISNULL';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ valueExpression
| timestampFunction # timestampFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
| LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr
| ident ARROW expression #lambda
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
| LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression #lambda
;

primaryExpression
Expand Down Expand Up @@ -568,6 +570,7 @@ evalFunctionName
| cryptographicFunctionName
| jsonFunctionName
| collectionFunctionName
| lambdaFcuntionName
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
;

functionArgs
Expand Down Expand Up @@ -875,6 +878,12 @@ collectionFunctionName
: ARRAY
;

// NOTE(review): rule name 'lambdaFcuntionName' is a typo for 'lambdaFunctionName'.
// Renaming it requires updating the reference in evalFunctionName in the same change,
// so it is only flagged here rather than fixed in place.
lambdaFcuntionName
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
: FORALL
| EXISTS
| FILTER
;

positionFunctionName
: POSITION
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.tree.FieldSummary;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -183,6 +184,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a PPL lambda-function expression node (e.g. {@code x -> x > 0}).
 * Default implementation simply delegates to {@code visitChildren}; subclasses
 * override this to give lambdas dedicated handling.
 */
public T visitPPLLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}

public T visitIsEmpty(IsEmpty node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
 * Expression node of a lambda function. Holds the lambda body expression (@function) and the
 * lambda's parameter names (@funcArgs); e.g. for {@code x -> x > 0}, funcArgs is {@code [x]}
 * and function is the expression {@code x > 0}.
 */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class LambdaFunction extends UnresolvedExpression {
// Lambda body expression.
private final UnresolvedExpression function;
// Lambda parameter names, in declaration order.
private final List<QualifiedName> funcArgs;

@Override
public List<UnresolvedExpression> getChild() {
// Body first, then the parameters, so visitors traverse the function before its args.
List<UnresolvedExpression> children = new ArrayList<>();
children.add(function);
children.addAll(funcArgs);
return children;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitPPLLambdaFunction(this, context);
qianheng-aws marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public String toString() {
// Renders in source-like form, e.g. "(x, y) -> x > y".
return String.format(
"(%s) -> %s",
funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
function.toString()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ public enum BuiltinFunctionName {
/** COLLECTION Functions **/
ARRAY(FunctionName.of("array")),

/** LAMBDA Functions **/
ARRAY_FORALL(FunctionName.of("forall")),
ARRAY_EXISTS(FunctionName.of("exists")),
ARRAY_FILTER(FunctionName.of("filter")),

/** NULL Test. */
IS_NULL(FunctionName.of("is null")),
IS_NOT_NULL(FunctionName.of("is not null")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
Expand All @@ -15,6 +16,7 @@
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
Expand All @@ -24,6 +26,8 @@
import org.apache.spark.sql.catalyst.expressions.RowFrame$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
import org.apache.spark.sql.catalyst.expressions.WindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
Expand All @@ -47,6 +51,7 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -68,7 +73,9 @@
import org.opensearch.sql.ppl.utils.AggregatorTransformer;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.JavaToScalaTransformer;
import scala.Option;
import scala.PartialFunction;
import scala.Tuple2;
import scala.collection.Seq;

Expand Down Expand Up @@ -432,6 +439,23 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys
return context.getNamedParseExpressions().push(udf);
}

@Override
public Expression visitPPLLambdaFunction(LambdaFunction node, CatalystPlanContext context) {
// Rewrite every UnresolvedAttribute inside the lambda body into an UnresolvedNamedLambdaVariable
// so Spark resolves those names against the lambda's own parameters rather than the input relation.
PartialFunction<Expression, Expression> transformer = JavaToScalaTransformer.toPartialFunction(
expr -> expr instanceof UnresolvedAttribute,
expr -> {
UnresolvedAttribute attr = (UnresolvedAttribute) expr;
return new UnresolvedNamedLambdaVariable(attr.nameParts());
}
);
Expression functionResult = node.getFunction().accept(this, context).transformUp(transformer);
// Visiting the body pushed its expression onto the context stack; pop it because the
// transformed copy is embedded in the LambdaFunction pushed below instead.
// NOTE(review): presumably accept() pushes exactly one expression here — confirm.
context.popNamedParseExpressions();
// Each declared lambda parameter becomes a named lambda variable for Catalyst.
List<NamedExpression> argsResult = node.getFuncArgs().stream()
.map(arg -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(arg.getParts())))
.collect(Collectors.toList());
// Third argument 'false' is Catalyst LambdaFunction's 'hidden' flag.
return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.IsEmpty;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand Down Expand Up @@ -429,6 +430,15 @@ public UnresolvedExpression visitTimestampFunctionCall(
ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx));
}

/**
 * Builds a LambdaFunction AST node from a parsed lambda expression, collecting the declared
 * parameter identifiers (one for {@code x -> ...}, several for {@code (x, y) -> ...}) and
 * visiting the body expression.
 */
@Override
public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) {

// Each identifier is resolved individually into a QualifiedName parameter.
List<QualifiedName> arguments = ctx.ident().stream().map(x -> this.visitIdentifiers(Collections.singletonList(x))).collect(
Collectors.toList());
UnresolvedExpression function = visitExpression(ctx.expression());
return new LambdaFunction(function, arguments);
}

private List<UnresolvedExpression> timestampFunctionArguments(
OpenSearchPPLParser.TimestampFunctionCallContext ctx) {
List<UnresolvedExpression> args =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl.utils;


import scala.PartialFunction;
import scala.runtime.AbstractPartialFunction;

/** Utility for adapting Java functional interfaces to Scala function types. */
public interface JavaToScalaTransformer {
  /**
   * Wraps a Java predicate/function pair into a Scala {@code PartialFunction}.
   *
   * @param isDefinedAt predicate deciding whether the partial function applies to a value
   * @param apply transformation applied when the predicate holds
   * @return a partial function that transforms matching values and returns others unchanged
   */
  static <T> PartialFunction<T, T> toPartialFunction(
      java.util.function.Predicate<T> isDefinedAt,
      java.util.function.Function<T, T> apply) {
    return new AbstractPartialFunction<T, T>() {
      @Override
      public boolean isDefinedAt(T value) {
        return isDefinedAt.test(value);
      }

      @Override
      public T apply(T value) {
        // Values outside the function's domain pass through unchanged.
        if (!isDefinedAt.test(value)) {
          return value;
        }
        return apply.apply(value);
      }
    };
  }
}
Loading
Loading