diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
index 2541b3743..ea7b7d2dc 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
@@ -5,17 +5,161 @@
 package org.opensearch.sql.expression.function;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.module.scala.DefaultScalaModule;
 import inet.ipaddr.AddressStringException;
 import inet.ipaddr.IPAddressString;
 import inet.ipaddr.IPAddressStringParameters;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
+import org.apache.spark.sql.types.DataTypes;
 import scala.Function2;
+import scala.Option;
 import scala.Serializable;
 import scala.runtime.AbstractFunction2;
 
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+
 public interface SerializableUdf {
 
-    Function2<String, String, Boolean> cidrFunction = new SerializableAbstractFunction2<>() {
+    ObjectMapper objectMapper = new ObjectMapper();
+
+    abstract class SerializableAbstractFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R>
+            implements Serializable {
+    }
+
+    /**
+     * Remove specified keys from a JSON string.
+     *
+     * @param jsonStr      The input JSON string.
+     * @param keysToRemove The list of keys to remove.
+     * @return A new JSON string without the specified keys.
+     */
+    Function2<String, List<String>, String> jsonDeleteFunction = new SerializableAbstractFunction2<>() {
+        @Override
+        public String apply(String jsonStr, List<String> keysToRemove) {
+            if (jsonStr == null) {
+                return null;
+            }
+            try {
+                Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
+                removeKeys(jsonMap, keysToRemove);
+                return objectMapper.writeValueAsString(jsonMap);
+            } catch (Exception e) {
+                return null;
+            }
+        }
+
+        private void removeKeys(Map<String, Object> map, List<String> keysToRemove) {
+            for (String key : keysToRemove) {
+                String[] keyParts = key.split("\\.");
+                Map<String, Object> currentMap = map;
+                for (int i = 0; i < keyParts.length - 1; i++) {
+                    String currentKey = keyParts[i];
+                    if (currentMap.containsKey(currentKey) && currentMap.get(currentKey) instanceof Map) {
+                        currentMap = (Map<String, Object>) currentMap.get(currentKey);
+                    } else {
+                        return; // Path not found, exit
+                    }
+                }
+                // Remove the final key if it exists
+                currentMap.remove(keyParts[keyParts.length - 1]);
+            }
+        }
+    };
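+    // Example (illustrative; assumes Jackson's default map deserialization,
+    // which preserves key order):
+    //   jsonDeleteFunction.apply("{\"a\":{\"b\":1,\"c\":2}}", List.of("a.b"))
+    //   returns "{\"a\":{\"c\":2}}"; a null or unparsable input yields null.
+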
+    /**
+     * Append values to JSON arrays based on specified path-value pairs.
+     *
+     * @param jsonStr        The input JSON string.
+     * @param pathValuePairs A list of path-value pairs to append.
+     * @return The updated JSON string.
+     */
+    Function2<String, List<Map.Entry<String, String>>, String> jsonAppendFunction = new SerializableAbstractFunction2<>() {
+        @Override
+        public String apply(String jsonStr, List<Map.Entry<String, String>> pathValuePairs) {
+            if (jsonStr == null) {
+                return null;
+            }
+            try {
+                Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
+
+                for (Map.Entry<String, String> pathValuePair : pathValuePairs) {
+                    String path = pathValuePair.getKey();
+                    String value = pathValuePair.getValue();
+
+                    if (jsonMap.containsKey(path) && jsonMap.get(path) instanceof List) {
+                        List<Object> existingList = (List<Object>) jsonMap.get(path);
+                        // Append the value to the end of the existing list
+                        existingList.add(value);
+                        jsonMap.put(path, existingList);
+                    } else if (jsonMap.containsKey(path)) {
+                        // Ignore appending if the existing value at the path is not an array
+                    } else {
+                        jsonMap.put(path, List.of(value));
+                    }
+                }
+
+                return objectMapper.writeValueAsString(jsonMap);
+            } catch (Exception e) {
+                return null; // Return null if parsing fails
+            }
+        }
+    };
+
+    /**
+     * Extend JSON arrays with new values based on specified path-value pairs.
+     *
+     * @param jsonStr        The input JSON string.
+     * @param pathValuePairs A list of path-value pairs to extend.
+     * @return The updated JSON string.
+     */
+    Function2<String, List<Map.Entry<String, List<String>>>, String> jsonExtendFunction = new SerializableAbstractFunction2<>() {
+
+        @Override
+        public String apply(String jsonStr, List<Map.Entry<String, List<String>>> pathValuePairs) {
+            if (jsonStr == null) {
+                return null;
+            }
+            try {
+                Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
+
+                for (Map.Entry<String, List<String>> pathValuePair : pathValuePairs) {
+                    String path = pathValuePair.getKey();
+                    List<String> values = pathValuePair.getValue();
+
+                    if (jsonMap.containsKey(path) && jsonMap.get(path) instanceof List) {
+                        List<Object> existingList = (List<Object>) jsonMap.get(path);
+                        existingList.addAll(values);
+                    } else {
+                        jsonMap.put(path, values);
+                    }
+                }
+
+                return objectMapper.writeValueAsString(jsonMap);
+            } catch (Exception e) {
+                return null; // Return null if parsing fails
+            }
+        }
+    };
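+    // Examples (illustrative; appended values are strings): given json = "{\"a\":[1]}",
+    //   jsonAppendFunction.apply(json, List.of(Map.entry("a", "2")))
+    //     returns "{\"a\":[1,\"2\"]}";
+    //   jsonExtendFunction.apply(json, List.of(Map.entry("a", List.of("2", "3"))))
+    //     returns "{\"a\":[1,\"2\",\"3\"]}".
+    // Both create a new array for a missing path; invalid JSON yields null.
+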
Error details: " + e.getMessage()); } - if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) { - throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6."); + if (parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) { + throw new RuntimeException("The given ipAddress '" + ipAddress + "' and cidrBlock '" + cidrBlock + "' are not compatible. Both must be either IPv4 or IPv6."); } return parsedCidrBlock.contains(parsedIpAddress); } }; - abstract class SerializableAbstractFunction2 extends AbstractFunction2 - implements Serializable { + /** + * get the function reference according to its name + * + * @param funcName + * @return + */ + static ScalaUDF visit(String funcName, List expressions) { + switch (funcName) { + case "cidr": + return new ScalaUDF(cidrFunction, + DataTypes.BooleanType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + case "json_delete": + return new ScalaUDF(jsonDeleteFunction, + DataTypes.StringType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("json_delete"), + false, + true); + case "json_extend": + return new ScalaUDF(jsonExtendFunction, + DataTypes.StringType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("json_extend"), + false, + true); + case "json_append": + return new ScalaUDF(jsonAppendFunction, + DataTypes.StringType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("json_append"), + false, + true); + default: + return null; + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index 35ac7ed47..bababe79e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -11,29 +11,21 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Cast$; -import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.InSubquery$; import org.apache.spark.sql.catalyst.expressions.LambdaFunction$; -import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.RowFrame$; -import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable; import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$; -import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; -import 
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
index 35ac7ed47..bababe79e 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
@@ -11,29 +11,21 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
 import org.apache.spark.sql.catalyst.expressions.CaseWhen;
 import org.apache.spark.sql.catalyst.expressions.Cast$;
-import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
 import org.apache.spark.sql.catalyst.expressions.Exists$;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.In$;
 import org.apache.spark.sql.catalyst.expressions.InSubquery$;
 import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
-import org.apache.spark.sql.catalyst.expressions.LessThan;
 import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.ListQuery$;
 import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
-import org.apache.spark.sql.catalyst.expressions.RowFrame$;
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
 import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
 import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
-import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
-import org.apache.spark.sql.catalyst.expressions.WindowExpression;
-import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
-import org.apache.spark.sql.types.DataTypes;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
 import org.opensearch.sql.ast.expression.AggregateFunction;
 import org.opensearch.sql.ast.expression.Alias;
@@ -44,7 +36,6 @@
 import org.opensearch.sql.ast.expression.Case;
 import org.opensearch.sql.ast.expression.Cast;
 import org.opensearch.sql.ast.expression.Compare;
-import org.opensearch.sql.ast.expression.DataType;
 import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Function;
 import org.opensearch.sql.ast.expression.In;
@@ -68,9 +59,7 @@
 import org.opensearch.sql.ast.tree.FillNull;
 import org.opensearch.sql.ast.tree.Kmeans;
 import org.opensearch.sql.ast.tree.RareTopN;
-import org.opensearch.sql.ast.tree.Trendline;
 import org.opensearch.sql.ast.tree.UnresolvedPlan;
-import org.opensearch.sql.expression.function.BuiltinFunctionName;
 import org.opensearch.sql.expression.function.SerializableUdf;
 import org.opensearch.sql.ppl.utils.AggregatorTransformer;
 import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
@@ -89,6 +78,7 @@
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
+import static java.util.List.of;
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
 import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
 import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs;
@@ -438,17 +428,7 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Cataly
         Expression ipAddressExpression = context.getNamedParseExpressions().pop();
         analyze(node.getCidrBlock(), context);
         Expression cidrBlockExpression = context.getNamedParseExpressions().pop();
-
-        ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction,
-                DataTypes.BooleanType,
-                seq(ipAddressExpression, cidrBlockExpression),
-                seq(),
-                Option.empty(),
-                Option.apply("cidr"),
-                false,
-                true);
-
-        return context.getNamedParseExpressions().push(udf);
+        return context.getNamedParseExpressions().push(SerializableUdf.visit("cidr", of(ipAddressExpression, cidrBlockExpression)));
     }
 
     @Override
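For context, the PPL shape that reaches visitCidr is a cidrmatch predicate,
e.g. (illustrative query; the field name is a placeholder):

    source=t | where cidrmatch(ip, '192.168.1.0/24')
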
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java
index 0a4f19b53..f73a1c491 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java
@@ -13,12 +13,15 @@
 import org.apache.spark.sql.catalyst.expressions.DateAddInterval$;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
 import org.apache.spark.sql.catalyst.expressions.TimestampAdd$;
 import org.apache.spark.sql.catalyst.expressions.TimestampDiff$;
 import org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp$;
 import org.apache.spark.sql.catalyst.expressions.UnaryMinus$;
 import org.opensearch.sql.ast.expression.IntervalUnit;
 import org.opensearch.sql.expression.function.BuiltinFunctionName;
+import org.opensearch.sql.expression.function.SerializableUdf;
+import org.opensearch.sql.ppl.CatalystPlanContext;
 import scala.Option;
 
 import java.util.Arrays;
@@ -26,7 +29,6 @@
 import java.util.Map;
 import java.util.function.Function;
 
-import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX;
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD;
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE;
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT;
@@ -76,7 +78,7 @@ public interface BuiltinFunctionTransformer {
      * This is only used for the built-in functions between PPL and Spark with different names.
      * If the built-in function names are the same in PPL and Spark, add it to {@link BuiltinFunctionName} only.
      */
-    static final Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
+    Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
         = ImmutableMap.<BuiltinFunctionName, String>builder()
             // arithmetic operators
             .put(ADD, "+")
@@ -117,7 +119,7 @@ public interface BuiltinFunctionTransformer {
     /**
      * The name mapping between PPL builtin functions to Spark builtin functions.
      */
-    static final Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
+    Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
         = ImmutableMap.<BuiltinFunctionName, Function<List<Expression>, Expression>>builder()
             // json functions
             .put(
@@ -176,9 +178,11 @@ public interface BuiltinFunctionTransformer {
 
     static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
         if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) {
-            // TODO change it when UDF is supported
-            // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions
-            throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
+            ScalaUDF udf = SerializableUdf.visit(function.getFuncName(), args);
+            if (udf == null) {
+                throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
+            }
+            return udf;
         } else {
             BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get();
             String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin);
diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
similarity index 87%
rename from ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java
rename to ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
index 3d3940730..c11c832c3 100644
--- a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java
+++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
@@ -1,9 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
 package org.opensearch.sql.expression.function;
 
 import org.junit.Assert;
 import org.junit.Test;
 
-public class SerializableUdfTest {
+import java.util.Arrays;
+import java.util.Collections;
+
+import static java.util.Collections.singletonList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+public class SerializableIPUdfTest {
 
     @Test(expected = RuntimeException.class)
     public void cidrNullIpTest() {
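A minimal sketch of the new fallback in builtinFunction (the function node and
argument list are placeholders): names that are not PPL built-ins are now
resolved against SerializableUdf before giving up.

    // "json_delete" is not a BuiltinFunctionName, so the UDF path is taken:
    Expression resolved = BuiltinFunctionTransformer.builtinFunction(jsonDeleteCall, args);
    // a name unknown to both maps still throws UnsupportedOperationException
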
diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java
new file mode 100644
index 000000000..2c7080622
--- /dev/null
+++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java
@@ -0,0 +1,186 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.sql.expression.function;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.opensearch.sql.expression.function.SerializableUdf.jsonAppendFunction;
+import static org.opensearch.sql.expression.function.SerializableUdf.jsonDeleteFunction;
+import static org.opensearch.sql.expression.function.SerializableUdf.jsonExtendFunction;
+
+public class SerializableJsonUdfTest {
+
+    @Test
+    public void testJsonDeleteFunctionRemoveSingleKey() {
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}";
+        String expectedJson = "{\"key1\":\"value1\",\"key3\":\"value3\"}";
+        String result = jsonDeleteFunction.apply(jsonStr, singletonList("key2"));
+        assertEquals(expectedJson, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionRemoveNestedKey() {
+        // Note the escaped double quotes within the JSON string
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":{ \"key3\":\"value3\",\"key4\":\"value4\" }}";
+        String expectedJson = "{\"key1\":\"value1\",\"key2\":{\"key4\":\"value4\"}}";
+        String result = jsonDeleteFunction.apply(jsonStr, singletonList("key2.key3"));
+        assertEquals(expectedJson, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionRemoveSingleArrayedKey() {
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"keyArray\":[\"value1\",\"value2\"]}";
+        String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
+        String result = jsonDeleteFunction.apply(jsonStr, singletonList("keyArray"));
+        assertEquals(expectedJson, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionRemoveMultipleKeys() {
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}";
+        String expectedJson = "{\"key3\":\"value3\"}";
+        String result = jsonDeleteFunction.apply(jsonStr, Arrays.asList("key1", "key2"));
+        assertEquals(expectedJson, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionRemoveMultipleSomeAreNestedKeys() {
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":{ \"key3\":\"value3\",\"key4\":\"value4\" }}";
+        String expectedJson = "{\"key2\":{\"key3\":\"value3\"}}";
+        String result = jsonDeleteFunction.apply(jsonStr, Arrays.asList("key1", "key2.key4"));
+        assertEquals(expectedJson, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionNoKeysRemoved() {
+        String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
+        String result = jsonDeleteFunction.apply(jsonStr, Collections.emptyList());
+        assertEquals(jsonStr, result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionNullJson() {
+        String result = jsonDeleteFunction.apply(null, Collections.singletonList("key1"));
+        assertNull(result);
+    }
+
+    @Test
+    public void testJsonDeleteFunctionInvalidJson() {
+        String invalidJson = "invalid_json";
+        String result = jsonDeleteFunction.apply(invalidJson, Collections.singletonList("key1"));
+        assertNull(result);
+    }
"{\"arrayKey\":[\"value1\",\"value2\"]}"; + String expectedJson = "{\"arrayKey\":[\"value1\",\"value2\",\"value3\"]}"; + Map.Entry pair = Map.entry("arrayKey", "value3"); + String result = jsonAppendFunction.apply(jsonStr, Collections.singletonList(pair)); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionAddNewArray() { + String jsonStr = "{\"key1\":\"value1\"}"; + String expectedJson = "{\"key1\":\"value1\",\"newArray\":[\"newValue\"]}"; + Map.Entry pair = Map.entry("newArray", "newValue"); + String result = jsonAppendFunction.apply(jsonStr, Collections.singletonList(pair)); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionIgnoreNonArrayKey() { + String jsonStr = "{\"key1\":\"value1\"}"; + String expectedJson = jsonStr; + Map.Entry pair = Map.entry("key1", "newValue"); + String result = jsonAppendFunction.apply(jsonStr, Collections.singletonList(pair)); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionMultipleAppends() { + String jsonStr = "{\"arrayKey\":[\"value1\"]}"; + String expectedJson = "{\"arrayKey\":[\"value1\",\"value2\",\"value3\"],\"newKey\":[\"newValue\"]}"; + List> pairs = Arrays.asList( + Map.entry("arrayKey", "value2"), + Map.entry("arrayKey", "value3"), + Map.entry("newKey", "newValue") + ); + String result = jsonAppendFunction.apply(jsonStr, pairs); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionNullJson() { + String result = jsonAppendFunction.apply(null, Collections.singletonList(Map.entry("key", "value"))); + assertNull(result); + } + + @Test + public void testJsonAppendFunctionInvalidJson() { + String invalidJson = "invalid_json"; + String result = jsonAppendFunction.apply(invalidJson, Collections.singletonList(Map.entry("key", "value"))); + assertNull(result); + } + + @Test + public void testJsonExtendFunctionWithExistingPath() { + String jsonStr = "{\"path1\": [\"value1\", \"value2\"]}"; + List>> pathValuePairs = new ArrayList<>(); + pathValuePairs.add(Map.entry("path1", asList("value3", "value4"))); + + String result = jsonExtendFunction.apply(jsonStr, pathValuePairs); + String expectedJson = "{\"path1\":[\"value1\",\"value2\",\"value3\",\"value4\"]}"; + + assertEquals(expectedJson, result); + } + + @Test + public void testJsonExtendFunctionWithNewPath() { + String jsonStr = "{\"path1\": [\"value1\"]}"; + List>> pathValuePairs = new ArrayList<>(); + pathValuePairs.add(Map.entry("path2", asList("newValue1", "newValue2"))); + + String result = jsonExtendFunction.apply(jsonStr, pathValuePairs); + String expectedJson = "{\"path1\":[\"value1\"],\"path2\":[\"newValue1\",\"newValue2\"]}"; + + assertEquals(expectedJson, result); + } + + @Test + public void testJsonExtendFunctionWithNullInput() { + String result = jsonExtendFunction.apply(null, Collections.emptyList()); + assertNull(result); + } + + @Test + public void testJsonExtendFunctionWithInvalidJson() { + String result = jsonExtendFunction.apply("invalid json", Collections.emptyList()); + assertNull(result); + } + + @Test + public void testJsonExtendFunctionWithNonArrayPath() { + String jsonStr = "{\"path1\":\"value1\"}"; + List>> pathValuePairs = new ArrayList<>(); + pathValuePairs.add(Map.entry("path1", asList("value2"))); + + String result = jsonExtendFunction.apply(jsonStr, pathValuePairs); + String expectedJson = "{\"path1\":[\"value2\"]}"; + + assertEquals(expectedJson, result); + } +} diff --git 
index 6193bc43f..34d0133e0 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala
@@ -8,12 +8,17 @@ package org.opensearch.flint.spark.ppl
 import org.opensearch.flint.spark.ppl.PlaneUtils.plan
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
 import org.scalatest.matchers.should.Matchers
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
+import org.apache.spark.sql.types.DataTypes
+import org.opensearch.sql.expression.function.SerializableUdf
+import org.opensearch.sql.expression.function.SerializableUdf.visit
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+
+import java.util
 
 class PPLLogicalPlanJsonFunctionsTranslatorTestSuite
   extends SparkFunSuite
@@ -185,6 +190,43 @@
     comparePlans(expectedPlan, logPlan, false)
   }
 
+  test("test json_delete()") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, """source=t a = json_delete('{"a":[{"b":1},{"c":2}]}', ["a.b"])"""),
+        context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val keysExpression = Literal("[a.b]")
+    val jsonObjExp = Literal("""{"a":[{"b":1},{"c":2}]}""")
+    val jsonFunc = visit("json_delete", util.List.of(jsonObjExp, keysExpression))
+    val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc)
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
+  test("test json_append()") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
+        plan(pplParser, """source=t a = json_append('{"a":[{"b":1},{"c":2}]}', 'a.b')"""),
+        context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val pathExpression = Literal("a.b")
+    val jsonObjExp = Literal("""{"a":[{"b":1},{"c":2}]}""")
+    val jsonFunc = visit("json_append", util.List.of(jsonObjExp, pathExpression))
+    val filterExpr = EqualTo(UnresolvedAttribute("a"), jsonFunc)
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
   test("test json_keys()") {
     val context = new CatalystPlanContext
     val logPlan =
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala
index 213f201cc..d14b0fee1 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala
@@ -10,13 +10,13 @@ import org.opensearch.sql.expression.function.SerializableUdf
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
 import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
 import org.scalatest.matchers.should.Matchers
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, ScalaUDF, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.DataTypes
+import org.opensearch.sql.expression.function.SerializableUdf.visit
 
 class PPLLogicalPlanParseCidrmatchTestSuite
   extends SparkFunSuite
@@ -41,15 +41,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite
     val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false))
     val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true))
-    val cidr = ScalaUDF(
-      SerializableUdf.cidrFunction,
-      DataTypes.BooleanType,
-      seq(ipAddress, cidrExpression),
-      seq(),
-      Option.empty,
-      Option.apply("cidr"),
-      false,
-      true)
+    val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression))
 
     val expectedPlan = Project(
       Seq(UnresolvedStar(None)),
@@ -71,15 +63,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite
     val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true))
     val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false))
-    val cidr = ScalaUDF(
-      SerializableUdf.cidrFunction,
-      DataTypes.BooleanType,
-      seq(ipAddress, cidrExpression),
-      seq(),
-      Option.empty,
-      Option.apply("cidr"),
-      false,
-      true)
+    val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression))
 
     val expectedPlan = Project(
       Seq(UnresolvedStar(None)),
@@ -100,15 +84,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite
     val cidrExpression = Literal("2003:db8::/32")
     val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true))
-    val cidr = ScalaUDF(
-      SerializableUdf.cidrFunction,
-      DataTypes.BooleanType,
-      seq(ipAddress, cidrExpression),
-      seq(),
-      Option.empty,
-      Option.apply("cidr"),
-      false,
-      true)
+    val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression))
 
     val expectedPlan = Project(
       Seq(UnresolvedAttribute("ip")),
@@ -130,15 +106,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite
     val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true))
     val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t")))
-    val cidr = ScalaUDF(
-      SerializableUdf.cidrFunction,
-      DataTypes.BooleanType,
-      seq(ipAddress, cidrExpression),
-      seq(),
-      Option.empty,
-      Option.apply("cidr"),
-      false,
-      true)
+    val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression))
 
     val equalTo = EqualTo(Literal(true), cidr)
     val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out"))