diff --git a/docs/ppl-lang/functions/ppl-json.md b/docs/ppl-lang/functions/ppl-json.md
index 2c0c0ca67..3eb952cb7 100644
--- a/docs/ppl-lang/functions/ppl-json.md
+++ b/docs/ppl-lang/functions/ppl-json.md
@@ -203,6 +203,90 @@ Example:
     +----------------+
 
+### `JSON_DELETE`
+
+**Description**
+
+`json_delete(json_string, [keys list])` deletes elements from a JSON object based on the specified keys. Returns the updated JSON object after the keys are deleted.
+
+**Argument type:** JSON_STRING, List
+
+**Return type:** JSON_STRING
+
+A JSON object in string format.
+
+Example:
+
+    os> source=people | eval deleted = json_delete('{"account_number":1,"balance":39225,"age":32,"gender":"M"}', array('age','gender')) | head 1 | fields deleted
+    fetched rows / total rows = 1/1
+    +--------------------------------------+
+    | deleted                              |
+    +--------------------------------------+
+    | {"account_number":1,"balance":39225} |
+    +--------------------------------------+
+
+    os> source=people | eval deleted = json_delete('{"f1":"abc","f2":{"f3":"a","f4":"b"}}', array('f2.f3')) | head 1 | fields deleted
+    fetched rows / total rows = 1/1
+    +------------------------------+
+    | deleted                      |
+    +------------------------------+
+    | {"f1":"abc","f2":{"f4":"b"}} |
+    +------------------------------+
+
+    os> source=people | eval deleted = json_delete('{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', array('teacher', 'student.rank')) | head 1 | fields deleted
+    fetched rows / total rows = 1/1
+    +-------------------------------------------------+
+    | deleted                                         |
+    +-------------------------------------------------+
+    | {"student":[{"name":"Bob"},{"name":"Charlie"}]} |
+    +-------------------------------------------------+
+
+### `JSON_APPEND`
+
+**Description**
+
+`json_append(json_string, [path_key, list of values to add])` appends values to the end of an array inside a JSON document. Returns the updated JSON object after appending.
+
+**Argument type:** JSON_STRING, List
+
+**Return type:** JSON_STRING
+
+A JSON object in string format.
+
+**Note**
+Append adds values to the tail of an existing array, with the following cases:
+ - path points to an object value - the append is ignored and the original JSON is returned
+ - path points to an existing, non-empty array - the values are appended to the array's tail
+ - path is not found - a new array containing the given values is created at the root of the JSON tree
+ - path points to an existing, empty array - a new array is created with the given values
+
+Example:
+
+    os> source=people | eval append = json_append('{"teacher":["Alice"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', array('student', '{"name":"Tomy","rank":5}')) | head 1 | fields append
+    fetched rows / total rows = 1/1
+    +----------------------------------------------------------------------------------------------------------------+
+    | append                                                                                                         |
+    +----------------------------------------------------------------------------------------------------------------+
+    | {"teacher":["Alice"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2},{"name":"Tomy","rank":5}]} |
+    +----------------------------------------------------------------------------------------------------------------+
+
+    os> source=people | eval append = json_append('{"teacher":["Alice"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', array('teacher', 'Tom', 'Walt')) | head 1 | fields append
+    fetched rows / total rows = 1/1
+    +----------------------------------------------------------------------------------------------------+
+    | append                                                                                             |
+    +----------------------------------------------------------------------------------------------------+
+    | {"teacher":["Alice","Tom","Walt"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]} |
+    +----------------------------------------------------------------------------------------------------+
+
+    os> source=people | eval append = json_append('{"school":{"teacher":["Alice"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}}', array('school.teacher', 'Tom', 'Walt')) | head 1 | fields append
+    fetched rows / total rows = 1/1
+    +---------------------------------------------------------------------------------------------------------------+
+    | append                                                                                                        |
+    +---------------------------------------------------------------------------------------------------------------+
+    | {"school":{"teacher":["Alice","Tom","Walt"],"student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}} |
+    +---------------------------------------------------------------------------------------------------------------+
+
 ### `JSON_KEYS`
 
 **Description**
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
index fca758101..7a00d9a07 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala
@@ -5,6 +5,10 @@
 package org.opensearch.flint.spark.ppl
 
+import java.util
+
+import org.opensearch.sql.expression.function.SerializableUdf.visit
+
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import
org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, Not} @@ -27,6 +31,11 @@ class FlintSparkPPLJsonFunctionITSuite private val validJson5 = "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}" private val validJson6 = "[1,2,3]" + private val validJson7 = + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}" + private val validJson8 = + "{\"school\":{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}}" + private val validJson9 = "{\"a\":[\"valueA\", \"valueB\"]}" private val invalidJson1 = "[1,2" private val invalidJson2 = "[invalid json]" private val invalidJson3 = "{\"invalid\": \"json\"" @@ -385,4 +394,278 @@ class FlintSparkPPLJsonFunctionITSuite null)) assertSameRows(expectedSeq, frame) } + + test("test json_delete() function: one key") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_delete('$validJson1',array('age')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("{\"account_number\":1,\"balance\":39225,\"gender\":\"M\"}")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = UnresolvedFunction("array", Seq(Literal("age")), isDistinct = false) + val jsonObjExp = + Literal("{\"account_number\":1,\"balance\":39225,\"age\":32,\"gender\":\"M\"}") + val jsonFunc = + Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_delete() function: multiple keys") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_delete('$validJson1',array('age','gender')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("{\"account_number\":1,\"balance\":39225}")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("age"), Literal("gender")), isDistinct = false) + val jsonObjExp = + Literal("{\"account_number\":1,\"balance\":39225,\"age\":32,\"gender\":\"M\"}") + val jsonFunc = + Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_delete() function: nested key") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_delete('$validJson2',array('f2.f3')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("{\"f1\":\"abc\",\"f2\":{\"f4\":\"b\"}}")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("f2.f3")), isDistinct = false) + val jsonObjExp = + Literal("{\"f1\":\"abc\",\"f2\":{\"f3\":\"a\",\"f4\":\"b\"}}") + val jsonFunc = 
+ Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_delete() function: multi depth keys ") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_delete('$validJson5',array('teacher', 'student.rank')) | head 1 | fields result + | """.stripMargin) + assertSameRows(Seq(Row("{\"student\":[{\"name\":\"Bob\"},{\"name\":\"Charlie\"}]}")), frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("teacher"), Literal("student.rank")), + isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_delete() function: key not found") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_delete('$validJson5',array('none')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("none")), isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add single value") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson7',array('teacher', 'Tom')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":[\"Alice\",\"Tom\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("teacher"), Literal("Tom")), isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = 
GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add single value key not found") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson7',array('headmaster', 'Tom')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}],\"headmaster\":[\"Tom\"]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("headmaster"), Literal("Tom")), isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add single Object key not found") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson7',array('headmaster', '{"name":"Tomy","rank":1}')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}],\"headmaster\":[{\"name\":\"Tomy\",\"rank\":1}]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("headmaster"), Literal("""{"name":"Tomy","rank":1}""")), + isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add single Object value") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson7',array('student', '{"name":"Tomy","rank":5}')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2},{\"name\":\"Tomy\",\"rank\":5}]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("student"), Literal("""{"name":"Tomy","rank":5}""")), + isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + 
Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add multi value") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson7',array('teacher', 'Tom', 'Walt')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"teacher\":[\"Alice\",\"Tom\",\"Walt\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("teacher"), Literal("Tom"), Literal("Walt")), + isDistinct = false) + val jsonObjExp = + Literal( + "{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_append() function: add nested value") { + val frame = sql(s""" + | source = $testTable + | | eval result = json_append('$validJson8',array('school.teacher', 'Tom', 'Walt')) | head 1 | fields result + | """.stripMargin) + assertSameRows( + Seq(Row( + "{\"school\":{\"teacher\":[\"Alice\",\"Tom\",\"Walt\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}}")), + frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("school.teacher"), Literal("Tom"), Literal("Walt")), + isDistinct = false) + val jsonObjExp = + Literal( + "{\"school\":{\"teacher\":[\"Alice\"],\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}}") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index d15f5c8e3..b7d615980 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -385,11 +385,11 @@ JSON_ARRAY: 'JSON_ARRAY'; JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; TO_JSON_STRING: 'TO_JSON_STRING'; JSON_EXTRACT: 'JSON_EXTRACT'; +JSON_DELETE : 'JSON_DELETE'; JSON_KEYS: 'JSON_KEYS'; JSON_VALID: 'JSON_VALID'; -//JSON_APPEND: 'JSON_APPEND'; -//JSON_DELETE: 'JSON_DELETE'; -//JSON_EXTEND: 'JSON_EXTEND'; +JSON_APPEND: 'JSON_APPEND'; +//JSON_EXTEND : 'JSON_EXTEND'; //JSON_SET: 'JSON_SET'; //JSON_ARRAY_ALL_MATCH: 'JSON_ARRAY_ALL_MATCH'; 
//JSON_ARRAY_ANY_MATCH: 'JSON_ARRAY_ANY_MATCH';
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index 2466a3d23..b990fd549 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -875,10 +875,10 @@ jsonFunctionName
     | JSON_ARRAY_LENGTH
     | TO_JSON_STRING
     | JSON_EXTRACT
+    | JSON_DELETE
+    | JSON_APPEND
     | JSON_KEYS
     | JSON_VALID
-//    | JSON_APPEND
-//    | JSON_DELETE
//    | JSON_EXTEND
//    | JSON_SET
//    | JSON_ARRAY_ALL_MATCH
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/JsonUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/JsonUtils.java
new file mode 100644
index 000000000..9ca6732c6
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/JsonUtils.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.function;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+public interface JsonUtils {
+    ObjectMapper objectMapper = new ObjectMapper();
+
+    static Object parseValue(String value) {
+        // Try parsing the value as JSON, fall back to the raw string if parsing fails
+        try {
+            return objectMapper.readValue(value, Object.class);
+        } catch (Exception e) {
+            // Primitive value, return as is
+            return value;
+        }
+    }
+
+    /**
+     * Append a value to the array addressed by a dot-notation path inside a JSON object.
+     *
+     * @param currentObj the JSON node (map or list) currently being traversed
+     * @param pathParts the dot-separated path, split into its segments
+     * @param depth the index of the path segment currently being resolved
+     * @param valueToAppend the value to add at the end of the target array
+     */
+    static void appendNestedValue(Object currentObj, String[] pathParts, int depth, Object valueToAppend) {
+        if (currentObj == null || depth >= pathParts.length) {
+            return;
+        }
+
+        if (currentObj instanceof Map) {
+            Map<String, Object> currentMap = (Map<String, Object>) currentObj;
+            String currentKey = pathParts[depth];
+
+            if (depth == pathParts.length - 1) {
+                // If it's the last key, append to the array
+                currentMap.computeIfAbsent(currentKey, k -> new ArrayList<>()); // Create list if not present
+                Object existingValue = currentMap.get(currentKey);
+
+                if (existingValue instanceof List) {
+                    List<Object> existingList = (List<Object>) existingValue;
+                    existingList.add(valueToAppend);
+                }
+            } else {
+                // Continue traversing
+                currentMap.computeIfAbsent(currentKey, k -> new LinkedHashMap<>()); // Create map if not present
+                appendNestedValue(currentMap.get(currentKey), pathParts, depth + 1, valueToAppend);
+            }
+        } else if (currentObj instanceof List) {
+            // If the current object is a list, process each map in the list
+            List<Object> list = (List<Object>) currentObj;
+            for (Object item : list) {
+                if (item instanceof Map) {
+                    appendNestedValue(item, pathParts, depth, valueToAppend);
+                }
+            }
+        }
+    }
+
+    /**
+     * Remove the key addressed by a dot-notation path from a nested JSON object.
+     *
+     * @param currentObj the JSON node (map or list) currently being traversed
+     * @param keyParts the dot-separated key, split into its segments
+     * @param depth the index of the key segment currently being resolved
+     */
+    static void removeNestedKey(Object currentObj, String[] keyParts, int depth) {
+        if (currentObj == null || depth >= keyParts.length) {
+            return;
+        }
+
+        if (currentObj instanceof Map) {
+            Map<String, Object> currentMap = (Map<String, Object>) currentObj;
+            String currentKey = keyParts[depth];
+
+            if (depth == keyParts.length - 1) {
+                // If it's the last key, remove it from the map
+                currentMap.remove(currentKey);
+            } else {
+                // If not the last key, continue traversing
+                if (currentMap.containsKey(currentKey)) {
+                    Object nextObj = currentMap.get(currentKey);
+
+                    if (nextObj instanceof List) {
+                        // If the value is a list, process each item in the list
+                        List<Object> list = (List<Object>) nextObj;
+                        for (int i = 0; i < list.size(); i++) {
+                            removeNestedKey(list.get(i), keyParts, depth + 1);
+                        }
+                    } else {
+                        // Continue traversing if it's a map
+                        removeNestedKey(nextObj, keyParts, depth + 1);
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
index 2541b3743..e80a26bc4 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java
@@ -8,14 +8,105 @@
 import inet.ipaddr.AddressStringException;
 import inet.ipaddr.IPAddressString;
 import inet.ipaddr.IPAddressStringParameters;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
+import org.apache.spark.sql.types.DataTypes;
 import scala.Function2;
+import scala.Option;
 import scala.Serializable;
+import scala.collection.JavaConverters;
+import scala.collection.mutable.WrappedArray;
 import scala.runtime.AbstractFunction2;
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static org.opensearch.sql.expression.function.JsonUtils.appendNestedValue;
+import static org.opensearch.sql.expression.function.JsonUtils.objectMapper;
+import static org.opensearch.sql.expression.function.JsonUtils.parseValue;
+import static org.opensearch.sql.expression.function.JsonUtils.removeNestedKey;
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+
 public interface SerializableUdf {
-    Function2<String, String, Boolean> cidrFunction = new SerializableAbstractFunction2<>() {
+
+    abstract class SerializableAbstractFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R>
+            implements Serializable {
+    }
+
+    /**
+     * Remove specified keys from a JSON string.
+     *
+     * @param jsonStr The input JSON string.
+     * @param keysToRemove The list of keys to remove.
+     * @return A new JSON string without the specified keys.
+     */
+    Function2<String, WrappedArray<String>, String> jsonDeleteFunction = new SerializableAbstractFunction2<>() {
+        @Override
+        public String apply(String jsonStr, WrappedArray<String> keysToRemove) {
+            if (jsonStr == null) {
+                return null;
+            }
+            try {
+                Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
+                removeKeys(jsonMap, keysToRemove);
+                return objectMapper.writeValueAsString(jsonMap);
+            } catch (Exception e) {
+                return null;
+            }
+        }
+
+        private void removeKeys(Map<String, Object> map, WrappedArray<String> keysToRemove) {
+            Collection<String> keys = JavaConverters.asJavaCollection(keysToRemove);
+            for (String key : keys) {
+                String[] keyParts = key.split("\\.");
+                removeNestedKey(map, keyParts, 0);
+            }
+        }
+    };
+
+    Function2<String, WrappedArray<String>, String> jsonAppendFunction = new SerializableAbstractFunction2<>() {
+        /**
+         * Append values to JSON arrays based on specified path-values.
+         *
+         * @param jsonStr The input JSON string.
+         * @param elements A list of path-values where the first item is the path and subsequent items are values to append.
+         * @return The updated JSON string.
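+         * <p>Behavior by path, mirroring the documentation note and the unit tests:
+         * an existing non-empty array gets the values appended to its tail; an
+         * existing empty array becomes a new array holding the given values; a
+         * missing path gets a new array created at that key; a path pointing at a
+         * non-array value is ignored and the original JSON is returned.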
+         */
+        public String apply(String jsonStr, WrappedArray<String> elements) {
+            if (jsonStr == null) {
+                return null;
+            }
+            try {
+                List<String> pathValues = JavaConverters.mutableSeqAsJavaList(elements);
+                if (pathValues.isEmpty()) {
+                    return jsonStr;
+                }
+
+                String path = pathValues.get(0);
+                String[] pathParts = path.split("\\.");
+                List<String> values = pathValues.subList(1, pathValues.size());
+
+                // Parse the JSON string into a Map
+                Map<String, Object> jsonMap = objectMapper.readValue(jsonStr, Map.class);
+
+                // Append each value at the specified path
+                for (String value : values) {
+                    Object parsedValue = parseValue(value); // Parse the value
+                    appendNestedValue(jsonMap, pathParts, 0, parsedValue);
+                }
+
+                // Convert the updated map back to JSON
+                return objectMapper.writeValueAsString(jsonMap);
+            } catch (Exception e) {
+                return null;
+            }
+        }
+    };
+
+    Function2<String, String, Boolean> cidrFunction = new SerializableAbstractFunction2<>() {
 
         IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder()
                 .allowEmpty(false)
@@ -32,7 +123,7 @@ public Boolean apply(String ipAddress, String cidrBlock) {
             try {
                 parsedIpAddress.validate();
             } catch (AddressStringException e) {
-                throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage());
+                throw new RuntimeException("The given ipAddress '" + ipAddress + "' is invalid. It must be a valid IPv4 or IPv6 address. Error details: " + e.getMessage());
             }
 
             IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions);
@@ -40,18 +131,54 @@ public Boolean apply(String ipAddress, String cidrBlock) {
             try {
                 parsedCidrBlock.validate();
             } catch (AddressStringException e) {
-                throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage());
+                throw new RuntimeException("The given cidrBlock '" + cidrBlock + "' is invalid. It must be a valid CIDR or netmask. Error details: " + e.getMessage());
             }
 
-            if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
-                throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6.");
+            if (parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
+                throw new RuntimeException("The given ipAddress '" + ipAddress + "' and cidrBlock '" + cidrBlock + "' are not compatible. Both must be either IPv4 or IPv6.");
             }
 
             return parsedCidrBlock.contains(parsedIpAddress);
         }
     };
 
-    abstract class SerializableAbstractFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R>
-            implements Serializable {
+    /**
+     * Get the function reference according to its name.
+     *
+     * @param funcName the PPL function name
+     * @param expressions the resolved argument expressions
+     * @return the matching ScalaUDF, or null when no UDF is registered under that name
+     */
+    static ScalaUDF visit(String funcName, List<Expression> expressions) {
+        switch (funcName) {
+            case "cidr":
+                return new ScalaUDF(cidrFunction,
+                        DataTypes.BooleanType,
+                        seq(expressions),
+                        seq(),
+                        Option.empty(),
+                        Option.apply("cidr"),
+                        false,
+                        true);
+            case "json_delete":
+                return new ScalaUDF(jsonDeleteFunction,
+                        DataTypes.StringType,
+                        seq(expressions),
+                        seq(),
+                        Option.empty(),
+                        Option.apply("json_delete"),
+                        false,
+                        true);
+            case "json_append":
+                return new ScalaUDF(jsonAppendFunction,
+                        DataTypes.StringType,
+                        seq(expressions),
+                        seq(),
+                        Option.empty(),
+                        Option.apply("json_append"),
+                        false,
+                        true);
+            default:
+                return null;
+        }
     }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
index bc14ba9d4..d9ace48ba 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
@@ -11,29 +11,21 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
 import org.apache.spark.sql.catalyst.expressions.CaseWhen;
 import org.apache.spark.sql.catalyst.expressions.Cast$;
-import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
 import org.apache.spark.sql.catalyst.expressions.Exists$;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.In$;
 import org.apache.spark.sql.catalyst.expressions.InSubquery$;
 import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
-import org.apache.spark.sql.catalyst.expressions.LessThan;
 import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.ListQuery$;
 import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
-import org.apache.spark.sql.catalyst.expressions.RowFrame$;
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
 import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
 import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
-import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
-import org.apache.spark.sql.catalyst.expressions.WindowExpression;
-import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
-import org.apache.spark.sql.types.DataTypes;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
 import org.opensearch.sql.ast.expression.AggregateFunction;
 import org.opensearch.sql.ast.expression.Alias;
@@ -44,7 +36,6 @@
 import org.opensearch.sql.ast.expression.Case;
 import org.opensearch.sql.ast.expression.Cast;
 import org.opensearch.sql.ast.expression.Compare;
-import org.opensearch.sql.ast.expression.DataType;
 import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Function;
 import
org.opensearch.sql.ast.expression.In; @@ -68,9 +59,7 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN; -import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.SerializableUdf; import org.opensearch.sql.ppl.utils.AggregatorTransformer; import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; @@ -89,6 +78,7 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyList; +import static java.util.List.of; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer.createIntervalArgs; @@ -438,17 +428,7 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys Expression ipAddressExpression = context.getNamedParseExpressions().pop(); analyze(node.getCidrBlock(), context); Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); - - ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddressExpression,cidrBlockExpression), - seq(), - Option.empty(), - Option.apply("cidr"), - false, - true); - - return context.getNamedParseExpressions().push(udf); + return context.getNamedParseExpressions().push(SerializableUdf.visit("cidr", of(ipAddressExpression,cidrBlockExpression))); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index 0a4f19b53..f73a1c491 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -13,12 +13,15 @@ import org.apache.spark.sql.catalyst.expressions.DateAddInterval$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.TimestampAdd$; import org.apache.spark.sql.catalyst.expressions.TimestampDiff$; import org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp$; import org.apache.spark.sql.catalyst.expressions.UnaryMinus$; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.CatalystPlanContext; import scala.Option; import java.util.Arrays; @@ -26,7 +29,6 @@ import java.util.Map; import java.util.function.Function; -import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT; @@ -76,7 +78,7 @@ public interface BuiltinFunctionTransformer { * This is only used for the built-in functions between PPL and Spark with different names. * If the built-in function names are the same in PPL and Spark, add it to {@link BuiltinFunctionName} only. 
 */
-    static final Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
+    Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING
         = ImmutableMap.<BuiltinFunctionName, String>builder()
             // arithmetic operators
             .put(ADD, "+")
@@ -117,7 +119,7 @@ public interface BuiltinFunctionTransformer {
     /**
      * The name mapping between PPL builtin functions to Spark builtin functions.
      */
-    static final Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
+    Map<BuiltinFunctionName, Function<List<Expression>, Expression>> PPL_TO_SPARK_FUNC_MAPPING
         = ImmutableMap.<BuiltinFunctionName, Function<List<Expression>, Expression>>builder()
             // json functions
             .put(
@@ -176,9 +178,11 @@
     static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
         if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) {
-            // TODO change it when UDF is supported
-            // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions
-            throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
+            ScalaUDF udf = SerializableUdf.visit(function.getFuncName(), args);
+            if (udf == null) {
+                throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
+            }
+            return udf;
         } else {
             BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get();
             String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin);
diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
similarity index 87%
rename from ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java
rename to ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
index 3d3940730..c11c832c3 100644
--- a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java
+++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableIPUdfTest.java
@@ -1,9 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
 package org.opensearch.sql.expression.function;
 
 import org.junit.Assert;
 import org.junit.Test;
 
-public class SerializableUdfTest {
+import java.util.Arrays;
+import java.util.Collections;
+
+import static java.util.Collections.singletonList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+public class SerializableIPUdfTest {
 
     @Test(expected = RuntimeException.class)
     public void cidrNullIpTest() {
diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java
new file mode 100644
index 000000000..fb47803cf
--- /dev/null
+++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableJsonUdfTest.java
@@ -0,0 +1,153 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.sql.expression.function;
+
+import org.junit.Test;
+import scala.collection.mutable.WrappedArray;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.opensearch.sql.expression.function.SerializableUdf.jsonAppendFunction;
+import static org.opensearch.sql.expression.function.SerializableUdf.jsonDeleteFunction;
+
+public class SerializableJsonUdfTest {
+
+    @Test
+    public
void testJsonDeleteFunctionRemoveSingleKey() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}"; + String expectedJson = "{\"key1\":\"value1\",\"key3\":\"value3\"}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"key2"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionRemoveNestedKey() { + // Correctly escape double quotes within the JSON string + String jsonStr = "{\"key1\":\"value1\",\"key2\":{ \"key3\":\"value3\",\"key4\":\"value4\" }}"; + String expectedJson = "{\"key1\":\"value1\",\"key2\":{\"key4\":\"value4\"}}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"key2.key3"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionRemoveSingleArrayedKey() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"keyArray\":[\"value1\",\"value2\"]}"; + String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"keyArray"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionRemoveMultipleKeys() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}"; + String expectedJson = "{\"key3\":\"value3\"}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"key1", "key2"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionRemoveMultipleSomeAreNestedKeys() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":{ \"key3\":\"value3\",\"key4\":\"value4\" }}"; + String expectedJson = "{\"key2\":{\"key3\":\"value3\"}}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"key1", "key2.key4"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionRemoveMultipleKeysNestedArrayKeys() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":[{ \"a\":\"valueA\",\"key3\":\"value3\"}, {\"a\":\"valueA\",\"key4\":\"value4\"}]}"; + String expectedJson = "{\"key2\":[{\"key3\":\"value3\"},{\"key4\":\"value4\"}]}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[]{"key1", "key2.a"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonDeleteFunctionNoKeysRemoved() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + String result = jsonDeleteFunction.apply(jsonStr, WrappedArray.make(new String[0])); + assertEquals(jsonStr, result); + } + + @Test + public void testJsonDeleteFunctionNullJson() { + String result = jsonDeleteFunction.apply(null, WrappedArray.make(new String[]{"key1"})); + assertNull(result); + } + + @Test + public void testJsonDeleteFunctionInvalidJson() { + String invalidJson = "invalid_json"; + String result = jsonDeleteFunction.apply(invalidJson, WrappedArray.make(new String[]{"key1"})); + assertNull(result); + } + + @Test + public void testJsonAppendFunctionAppendToExistingArray() { + String jsonStr = "{\"arrayKey\":[\"value1\",\"value2\"]}"; + String expectedJson = "{\"arrayKey\":[\"value1\",\"value2\",\"value3\"]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"arrayKey", "value3"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionAppendObjectToExistingArray() { + String jsonStr = "{\"key1\":\"value1\",\"key2\":[{\"a\":\"valueA\",\"key3\":\"value3\"}]}"; + String expectedJson = 
"{\"key1\":\"value1\",\"key2\":[{\"a\":\"valueA\",\"key3\":\"value3\"},{\"a\":\"valueA\",\"key4\":\"value4\"}]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"key2", "{\"a\":\"valueA\",\"key4\":\"value4\"}"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionAddNewArray() { + String jsonStr = "{\"key1\":\"value1\",\"newArray\":[]}"; + String expectedJson = "{\"key1\":\"value1\",\"newArray\":[\"newValue\"]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"newArray", "newValue"})); + assertEquals(expectedJson, result); + } + @Test + public void testJsonAppendFunctionNoSuchKey() { + String jsonStr = "{\"key1\":\"value1\"}"; + String expectedJson = "{\"key1\":\"value1\",\"newKey\":[\"newValue\"]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"newKey", "newValue"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionIgnoreNonArrayKey() { + String jsonStr = "{\"key1\":\"value1\"}"; + String expectedJson = jsonStr; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"key1", "newValue"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionWithNestedArrayKeys() { + String jsonStr = "{\"key2\":[{\"a\":[\"Value1\"],\"key3\":\"Value3\"},{\"a\":[\"Value1\"],\"key4\":\"Value4\"}]}"; + String expectedJson = "{\"key2\":[{\"a\":[\"Value1\",\"Value2\"],\"key3\":\"Value3\"},{\"a\":[\"Value1\",\"Value2\"],\"key4\":\"Value4\"}]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"key2.a","Value2"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionWithObjectKey() { + String jsonStr = "{\"key2\":[{\"a\":[\"Value1\"],\"key3\":\"Value3\"},{\"a\":[\"Value1\"],\"key4\":\"Value4\"}]}"; + String expectedJson = "{\"key2\":[{\"a\":[\"Value1\"],\"key3\":\"Value3\"},{\"a\":[\"Value1\"],\"key4\":\"Value4\"},\"Value2\"]}"; + String result = jsonAppendFunction.apply(jsonStr, WrappedArray.make(new String[]{"key2","Value2"})); + assertEquals(expectedJson, result); + } + + @Test + public void testJsonAppendFunctionNullJson() { + String result = jsonAppendFunction.apply(null, WrappedArray.make(new String[]{"key1", "newValue"})); + assertNull(result); + } + + @Test + public void testJsonAppendFunctionInvalidJson() { + String invalidJson = "invalid_json"; + String result = jsonAppendFunction.apply(invalidJson, WrappedArray.make(new String[]{"key1", "newValue"})); + assertNull(result); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala index 6193bc43f..fae070a75 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJsonFunctionsTranslatorTestSuite.scala @@ -5,15 +5,21 @@ package org.opensearch.flint.spark.ppl +import java.util + import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.expression.function.SerializableUdf.visit import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} +import org.apache.spark.sql.types.DataTypes class PPLLogicalPlanJsonFunctionsTranslatorTestSuite extends SparkFunSuite @@ -185,6 +191,51 @@ class PPLLogicalPlanJsonFunctionsTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test json_delete()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval result = json_delete('{"a":[{"b":1},{"c":2}]}', array('a.b'))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val keysExpression = + UnresolvedFunction("array", Seq(Literal("a.b")), isDistinct = false) + val jsonObjExp = + Literal("""{"a":[{"b":1},{"c":2}]}""") + val jsonFunc = + Alias(visit("json_delete", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, false) + } + + test("test json_append()") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + """source=t | eval result = json_append('{"a":[{"b":1},{"c":2}]}', array('a.b','c','d'))"""), + context) + + val table = UnresolvedRelation(Seq("t")) + val keysExpression = + UnresolvedFunction( + "array", + Seq(Literal("a.b"), Literal("c"), Literal("d")), + isDistinct = false) + val jsonObjExp = + Literal("""{"a":[{"b":1},{"c":2}]}""") + val jsonFunc = + Alias(visit("json_append", util.List.of(jsonObjExp, keysExpression)), "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), eval) + comparePlans(expectedPlan, logPlan, false) + } + test("test json_keys()") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala index 213f201cc..c8a8a67ad 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseCidrmatchTestSuite.scala @@ -7,13 +7,14 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.expression.function.SerializableUdf.visit import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, 
RegExpExtract, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.DataTypes @@ -41,15 +42,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(false)) val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(true)) - val cidr = ScalaUDF( - SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddress, cidrExpression), - seq(), - Option.empty, - Option.apply("cidr"), - false, - true) + val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression)) val expectedPlan = Project( Seq(UnresolvedStar(None)), @@ -71,15 +64,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) val filterIsValid = EqualTo(UnresolvedAttribute("isValid"), Literal(false)) - val cidr = ScalaUDF( - SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddress, cidrExpression), - seq(), - Option.empty, - Option.apply("cidr"), - false, - true) + val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression)) val expectedPlan = Project( Seq(UnresolvedStar(None)), @@ -100,15 +85,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite val cidrExpression = Literal("2003:db8::/32") val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) - val cidr = ScalaUDF( - SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddress, cidrExpression), - seq(), - Option.empty, - Option.apply("cidr"), - false, - true) + val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression)) val expectedPlan = Project( Seq(UnresolvedAttribute("ip")), @@ -130,15 +107,7 @@ class PPLLogicalPlanParseCidrmatchTestSuite val filterIpv6 = EqualTo(UnresolvedAttribute("isV6"), Literal(true)) val filterClause = Filter(filterIpv6, UnresolvedRelation(Seq("t"))) - val cidr = ScalaUDF( - SerializableUdf.cidrFunction, - DataTypes.BooleanType, - seq(ipAddress, cidrExpression), - seq(), - Option.empty, - Option.apply("cidr"), - false, - true) + val cidr = visit("cidr", java.util.List.of(ipAddress, cidrExpression)) val equalTo = EqualTo(Literal(true), cidr) val caseFunction = CaseWhen(Seq((equalTo, Literal("in"))), Literal("out"))
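For reference, a minimal sketch of exercising the two new UDF lambdas directly, outside of Spark. It mirrors the cases in `SerializableJsonUdfTest` above and assumes the `ppl-spark-integration` classes and the Scala library are on the classpath; the class name `JsonUdfSketch` is illustrative only:

```java
import org.opensearch.sql.expression.function.SerializableUdf;
import scala.collection.mutable.WrappedArray;

public class JsonUdfSketch {
    public static void main(String[] args) {
        // Delete a top-level key and a nested key; keys are dot-notation paths.
        String deleted = SerializableUdf.jsonDeleteFunction.apply(
                "{\"key1\":\"value1\",\"key2\":{\"key3\":\"value3\",\"key4\":\"value4\"}}",
                WrappedArray.make(new String[]{"key1", "key2.key3"}));
        System.out.println(deleted); // {"key2":{"key4":"value4"}}

        // Append to an existing array; the first array element is the path,
        // the remaining elements are the values to add.
        String appended = SerializableUdf.jsonAppendFunction.apply(
                "{\"arrayKey\":[\"value1\",\"value2\"]}",
                WrappedArray.make(new String[]{"arrayKey", "value3"}));
        System.out.println(appended); // {"arrayKey":["value1","value2","value3"]}

        // Invalid JSON input yields null instead of throwing.
        System.out.println(SerializableUdf.jsonDeleteFunction.apply(
                "invalid_json", WrappedArray.make(new String[]{"key1"}))); // null
    }
}
```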
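And a sketch of the traversal helpers in `JsonUtils`, showing how a dot-notation path fans out over array elements; the document and paths are taken from the doc examples, and the class name `JsonUtilsSketch` is again illustrative:

```java
import java.util.Map;

import org.opensearch.sql.expression.function.JsonUtils;

public class JsonUtilsSketch {
    public static void main(String[] args) throws Exception {
        // "student" is an array of objects, so a path like student.rank is
        // applied to every element of the array.
        Map<String, Object> doc = JsonUtils.objectMapper.readValue(
                "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}",
                Map.class);

        // The UDF wrappers split the dot-notation path before calling the
        // helpers; each recursion level consumes one path segment.
        JsonUtils.removeNestedKey(doc, new String[]{"student", "rank"}, 0);

        // parseValue turns JSON text into a nested structure before appending,
        // so the appended element becomes an object rather than a string.
        JsonUtils.appendNestedValue(doc, new String[]{"student"}, 0,
                JsonUtils.parseValue("{\"name\":\"Tomy\",\"rank\":5}"));

        System.out.println(JsonUtils.objectMapper.writeValueAsString(doc));
        // {"teacher":"Alice","student":[{"name":"Bob"},{"name":"Charlie"},{"name":"Tomy","rank":5}]}
    }
}
```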