Add the following JSON functions & tests for:
- json_delete
- json_append
- json_extend

Signed-off-by: YANGDB <[email protected]>
YANG-DB committed Dec 9, 2024
1 parent 2bde70b commit b7b0713
Showing 7 changed files with 455 additions and 75 deletions.
@@ -5,17 +5,161 @@

package org.opensearch.sql.expression.function;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.scala.DefaultScalaModule;
import inet.ipaddr.AddressStringException;
import inet.ipaddr.IPAddressString;
import inet.ipaddr.IPAddressStringParameters;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.types.DataTypes;
import scala.Function2;
import scala.Option;
import scala.Serializable;
import scala.runtime.AbstractFunction2;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;


public interface SerializableUdf {

Function2<String,String,Boolean> cidrFunction = new SerializableAbstractFunction2<>() {
ObjectMapper objectMapper = new ObjectMapper();

abstract class SerializableAbstractFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R>
implements Serializable {
}

/**
* Remove specified keys from a JSON string.
*
* @param jsonStr The input JSON string.
* @param keysToRemove The list of keys to remove.
* @return A new JSON string without the specified keys.
*/
Function2<String, List<String>, String> jsonDeleteFunction = new SerializableAbstractFunction2<>() {
@Override
public String apply(String jsonStr, List<String> keysToRemove) {
if (jsonStr == null) {
return null;
}
try {
Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
removeKeys(jsonMap, keysToRemove);
return objectMapper.writeValueAsString(jsonMap);
} catch (Exception e) {
return null;
}
}

private void removeKeys(Map<String, Object> map, List<String> keysToRemove) {
for (String key : keysToRemove) {
String[] keyParts = key.split("\\.");
Map<String, Object> currentMap = map;
for (int i = 0; i < keyParts.length - 1; i++) {
String currentKey = keyParts[i];
if (currentMap.containsKey(currentKey) && currentMap.get(currentKey) instanceof Map) {
currentMap = (Map<String, Object>) currentMap.get(currentKey);
} else {
return; // Path not found, exit
}
}
// Remove the final key if it exists
currentMap.remove(keyParts[keyParts.length - 1]);
}
}
};
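
A minimal usage sketch of jsonDeleteFunction (editorial illustration, not part of the commit), assuming the Function2 is invoked directly:

String json = "{\"a\":{\"b\":1,\"c\":2},\"d\":3}";
// Dotted paths address nested keys: "a.b" removes b inside a, "d" removes the top-level key.
String result = SerializableUdf.jsonDeleteFunction.apply(json, List.of("a.b", "d"));
// result: {"a":{"c":2}}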

Function2<String, List<Map.Entry<String, String>>, String> jsonAppendFunction = new SerializableAbstractFunction2<>() {

/**
* Append values to JSON arrays based on specified path-value pairs.
*
* @param jsonStr The input JSON string.
* @param pathValuePairs A list of path-value pairs to append.
* @return The updated JSON string.
*/
public String apply(String jsonStr, List<Map.Entry<String, String>> pathValuePairs) {
if (jsonStr == null) {
return null;
}
try {
Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);

for (Map.Entry<String, String> pathValuePair : pathValuePairs) {
String path = pathValuePair.getKey();
String value = pathValuePair.getValue();

if (jsonMap.containsKey(path) && jsonMap.get(path) instanceof List) {
List<Object> existingList = (List<Object>) jsonMap.get(path);
// Append the value to the end of the existing list (a java.util.List produced by Jackson)
existingList.add(value);
jsonMap.put(path, existingList);
} else if (jsonMap.containsKey(path)) {
// The existing value at this path is not an array; ignore the append
} else {
jsonMap.put(path, List.of(value));
}
}

return objectMapper.writeValueAsString(jsonMap);
} catch (Exception e) {
return null; // Return null if parsing fails
}
}
};
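
For illustration (not part of the commit), the expected behavior of jsonAppendFunction for an existing array, a non-array value, and a missing key; note that values are appended as strings:

String json = "{\"arr\":[1,2],\"s\":\"x\"}";
String out = SerializableUdf.jsonAppendFunction.apply(json, List.of(
Map.entry("arr", "3"),   // appended to the existing array
Map.entry("s", "y"),     // ignored: "s" is not an array
Map.entry("tags", "z"))); // created as a new single-element list
// out: {"arr":[1,2,"3"],"s":"x","tags":["z"]}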

/**
* Extend JSON arrays with new values based on specified path-value pairs.
*
* @param jsonStr The input JSON string.
* @param pathValuePairs A list of path-value pairs to extend.
* @return The updated JSON string.
*/
Function2<String, List<Map.Entry<String, List<String>>>, String> jsonExtendFunction = new SerializableAbstractFunction2<>() {

@Override
public String apply(String jsonStr, List<Map.Entry<String, List<String>>> pathValuePairs) {
if (jsonStr == null) {
return null;
}
try {
Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);

for (Map.Entry<String, List<String>> pathValuePair : pathValuePairs) {
String path = pathValuePair.getKey();
List<String> values = pathValuePair.getValue();

if (jsonMap.containsKey(path) && jsonMap.get(path) instanceof List) {
List<Object> existingList = (List<Object>) jsonMap.get(path);
existingList.addAll(values);
} else {
jsonMap.put(path, values);
}
}

return objectMapper.writeValueAsString(jsonMap);
} catch (Exception e) {
return null; // Return null if parsing fails
}
}
};
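
Similarly, a sketch of jsonExtendFunction (illustration only), which appends whole lists of values rather than single values:

String json = "{\"arr\":[\"a\"]}";
String out = SerializableUdf.jsonExtendFunction.apply(json, List.of(
Map.entry("arr", List.of("b", "c"))));
// out: {"arr":["a","b","c"]}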

/**
* Check if a key matches the given path expression.
*
* @param key The key to check.
* @param path The path expression (e.g., "a.b").
* @return True if the key matches, false otherwise.
*/
private static boolean matchesKey(String key, String path) {
return key.equals(path) || key.startsWith(path + ".");
}

Function2<String, String, Boolean> cidrFunction = new SerializableAbstractFunction2<>() {

IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder()
.allowEmpty(false)
@@ -32,26 +176,71 @@ public Boolean apply(String ipAddress, String cidrBlock) {
try {
parsedIpAddress.validate();
} catch (AddressStringException e) {
throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage());
throw new RuntimeException("The given ipAddress '" + ipAddress + "' is invalid. It must be a valid IPv4 or IPv6 address. Error details: " + e.getMessage());
}

IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions);

try {
parsedCidrBlock.validate();
} catch (AddressStringException e) {
throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage());
throw new RuntimeException("The given cidrBlock '" + cidrBlock + "' is invalid. It must be a valid CIDR or netmask. Error details: " + e.getMessage());
}

if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6.");
if (parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
throw new RuntimeException("The given ipAddress '" + ipAddress + "' and cidrBlock '" + cidrBlock + "' are not compatible. Both must be either IPv4 or IPv6.");
}

return parsedCidrBlock.contains(parsedIpAddress);
}
};

abstract class SerializableAbstractFunction2<T1,T2,R> extends AbstractFunction2<T1,T2,R>
implements Serializable {
/**
* Get the UDF reference according to its name.
*
* @param funcName the name of the UDF to resolve
* @param expressions the argument expressions for the UDF
* @return the corresponding ScalaUDF, or null if the name is not a known UDF
*/
static ScalaUDF visit(String funcName, List<Expression> expressions) {
switch (funcName) {
case "cidr":
return new ScalaUDF(cidrFunction,
DataTypes.BooleanType,
seq(expressions),
seq(),
Option.empty(),
Option.apply("cidr"),
false,
true);
case "json_delete":
return new ScalaUDF(jsonDeleteFunction,
DataTypes.StringType,
seq(expressions),
seq(),
Option.empty(),
Option.apply("json_delete"),
false,
true);
case "json_extend":
return new ScalaUDF(jsonExtendFunction,
DataTypes.StringType,
seq(expressions),
seq(),
Option.empty(),
Option.apply("json_extend"),
false,
true);
case "json_append":
return new ScalaUDF(jsonAppendFunction,
DataTypes.StringType,
seq(expressions),
seq(),
Option.empty(),
Option.apply("json_append"),
false,
true);
default:
return null;
}
}
}
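
For orientation (editorial sketch, not committed code), resolving one of the new UDFs through the visit dispatcher; jsonExpr and keysExpr stand for hypothetical, already-analyzed Expression instances:

ScalaUDF udf = SerializableUdf.visit("json_delete", List.of(jsonExpr, keysExpr));
// Returns a StringType ScalaUDF wrapping jsonDeleteFunction; an unknown name returns null.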
@@ -11,29 +11,21 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.RowFrame$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
import org.apache.spark.sql.catalyst.expressions.WindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.DataTypes;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
@@ -44,7 +36,6 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
@@ -68,9 +59,7 @@
import org.opensearch.sql.ast.tree.FillNull;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.SerializableUdf;
import org.opensearch.sql.ppl.utils.AggregatorTransformer;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
@@ -89,6 +78,7 @@
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.List.of;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs;
@@ -438,17 +428,7 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) {
Expression ipAddressExpression = context.getNamedParseExpressions().pop();
analyze(node.getCidrBlock(), context);
Expression cidrBlockExpression = context.getNamedParseExpressions().pop();

ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction,
DataTypes.BooleanType,
seq(ipAddressExpression,cidrBlockExpression),
seq(),
Option.empty(),
Option.apply("cidr"),
false,
true);

return context.getNamedParseExpressions().push(udf);
return context.getNamedParseExpressions().push(SerializableUdf.visit("cidr", of(ipAddressExpression, cidrBlockExpression)));
}

@Override
@@ -13,20 +13,22 @@
import org.apache.spark.sql.catalyst.expressions.DateAddInterval$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.TimestampAdd$;
import org.apache.spark.sql.catalyst.expressions.TimestampDiff$;
import org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp$;
import org.apache.spark.sql.catalyst.expressions.UnaryMinus$;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.SerializableUdf;
import org.opensearch.sql.ppl.CatalystPlanContext;
import scala.Option;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT;
@@ -76,7 +78,7 @@ public interface BuiltinFunctionTransformer {
* This is only used for the built-in functions between PPL and Spark with different names.
* If the built-in function names are the same in PPL and Spark, add it to {@link BuiltinFunctionName} only.
*/
static final Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
= ImmutableMap.<BuiltinFunctionName, String>builder()
// arithmetic operators
.put(ADD, "+")
@@ -117,7 +119,7 @@ public interface BuiltinFunctionTransformer {
/**
* The name mapping between PPL builtin functions to Spark builtin functions.
*/
static final Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
= ImmutableMap.<BuiltinFunctionName, Function<List<Expression>, Expression>>builder()
// json functions
.put(
@@ -176,9 +178,11 @@

static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) {
// TODO change it when UDF is supported
// TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions
throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
ScalaUDF udf = SerializableUdf.visit(function.getFuncName(), args);
if (udf == null) {
throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
}
return udf;
} else {
BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get();
String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin);
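
A sketch of the new resolution order in builtinFunction (funcNode and args are hypothetical placeholders, for illustration only): a name that is not a PPL builtin is now tried against the serializable UDFs before failing.

// funcNode is a hypothetical org.opensearch.sql.ast.expression.Function for "json_delete";
// args is its List<Expression> of analyzed arguments.
Expression resolved = BuiltinFunctionTransformer.builtinFunction(funcNode, args);
// "json_delete" resolves via SerializableUdf.visit; a truly unknown name still throws
// UnsupportedOperationException.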
@@ -1,9 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.sql.expression.function;

import org.junit.Assert;
import org.junit.Test;

public class SerializableUdfTest {
import java.util.Arrays;
import java.util.Collections;

import static java.util.Collections.singletonList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

public class SerializableIPUdfTest {

@Test(expected = RuntimeException.class)
public void cidrNullIpTest() {
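
The JSON-UDF tests themselves are truncated above; a minimal JUnit sketch of what such a test could look like (hypothetical, not the committed test code):

@Test
public void jsonDeleteNestedKeyTest() {
String result = SerializableUdf.jsonDeleteFunction.apply(
"{\"a\":{\"b\":1,\"c\":2}}", singletonList("a.b"));
assertEquals("{\"a\":{\"c\":2}}", result);
}

@Test
public void jsonDeleteNullJsonTest() {
assertNull(SerializableUdf.jsonDeleteFunction.apply(null, singletonList("a")));
}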