From 7bae01ca367728da13ee613c5bffc0ce26ac660f Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 15 Oct 2024 15:51:16 +0800 Subject: [PATCH] first commit Signed-off-by: Lantao Jin --- .../flint/spark/FlintSparkSuite.scala | 26 ++ .../FlintSparkPPLJsonFunctionITSuite.scala | 359 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 18 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 22 ++ .../function/BuiltinFunctionName.java | 19 + .../ppl/utils/BuiltinFunctionTranslator.java | 60 ++- 6 files changed, 500 insertions(+), 4 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 1ecf48d28..3e340747a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -642,4 +642,30 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (6, 403, '/home', null) | """.stripMargin) } + + protected def createNullableJsonContentTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | jString STRING, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, '{"account_number":1,"balance":39225,"age":32,"gender":"M"}', true), + | (2, '{"f1":"abc","f2":{"f3":"a","f4":"b"}}', true), + | (3, '[1,2,3,{"f1":1,"f2":[5,6]},4]', true), + | (4, '[]', true), + | (5, '{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}', true), + | (6, '[1,2', false), + | (7, '[invalid json]', false), + | (8, '{"invalid": "json"', false), + | (9, 'invalid json', false), + | (0, null, false) + | """.stripMargin) + } } diff --git 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala new file mode 100644 index 000000000..3a0a5516c --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJsonFunctionITSuite.scala @@ -0,0 +1,359 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal, Not} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLJsonFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + private val validJson1 = "{\"account_number\":1,\"balance\":39225,\"age\":32,\"gender\":\"M\"}" + private val validJson2 = "{\"f1\":\"abc\",\"f2\":{\"f3\":\"a\",\"f4\":\"b\"}}" + private val validJson3 = "[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]" + private val validJson4 = "[]" + private val validJson5 = + "{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}" + private val invalidJson1 = "[1,2" + private val invalidJson2 = "[invalid json]" + private val invalidJson3 = "{\"invalid\": \"json\"" + private val invalidJson4 = "invalid json" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job 
=> + job.stop() + job.awaitTermination() + } + } + + test("test json() function: valid JSON") { + Seq(validJson1, validJson2, validJson3, validJson4, validJson5).foreach { jsonStr => + val frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(jsonStr)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + } + + test("test json() function: invalid JSON") { + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4).foreach { jsonStr => + val frame = sql(s""" + | source = $testTable + | | eval result = json('$jsonStr') | head 1 | fields result + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(null)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(Literal(jsonStr), Literal("$")), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + 
} + } + + test("test json() function on field") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true | eval result = json(jString) | fields result + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Seq(validJson1, validJson2, validJson3, validJson4, validJson5).map(Row.apply(_)).toArray + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val frame2 = sql(s""" + | source = $testTable + | | where isValid = false | eval result = json(jString) | fields result + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + val expectedResults2: Array[Row] = + Array(Row(null), Row(null), Row(null), Row(null), Row(null)) + assert(results2.sameElements(expectedResults2)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_array()") { + // test string array + var frame = sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 'a', 'string', 'array') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows( + Seq(Row(Array("this", "is", "a", "string", "array"))), + frame.collect().toSeq) + + // test empty array + frame = sql(s""" + | source = $testTable | eval result = json_array() | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(Array())), frame.collect().toSeq) + + // test number array + frame = sql(s""" 
+ | source = $testTable | eval result = json_array(1, 2, 0, -1, 1.1, -0.11) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(Array(1.0, 2.0, 0.0, -1.0, 1.1, -0.11))), frame.collect().toSeq) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "array", + Seq(Literal(1), Literal(2), Literal(0), Literal(-1), Literal(1.1), Literal(-0.11)), + isDistinct = false)), + isDistinct = false), + "result")() + val eval = Project(Seq(UnresolvedStar(None), jsonFunc), table) + val limit = GlobalLimit(Literal(1), LocalLimit(Literal(1), eval)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), limit) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + // item in json_array should all be the same type + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval result = json_array('this', 'is', 1.1, -0.11, true, false) | head 1 | fields result + | """.stripMargin)) + assert(ex.getMessage().contains("should all be the same type")) + } + + test("test json_array_length()") { + var frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array('this', 'is', 'a', 'string', 'array')) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(5)), frame.collect().toSeq) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array(1, 2, 0, -1, 1.1, -0.11)) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(6)), frame.collect().toSeq) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length(json_array()) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(0)), frame.collect().toSeq) + + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[]') | head 1 | fields result + | 
""".stripMargin) + QueryTest.sameRows(Seq(Row(0)), frame.collect().toSeq) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,4]') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(4)), frame.collect().toSeq) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2,3,{"f1":1,"f2":[5,6]},4]') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(5)), frame.collect().toSeq) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('{\"key\": 1}') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(null)), frame.collect().toSeq) + frame = sql(s""" + | source = $testTable | eval result = json_array_length('[1,2') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row(null)), frame.collect().toSeq) + } + + test("test json_object()") { + // test value is a string + var frame = sql(s""" + | source = $testTable | eval result = json_object('key', 'string') | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row("""{"key":"string"}""")), frame.collect().toSeq) + + // test value is a number + frame = sql(s""" + | source = $testTable | eval result = json_object('key', 123.45) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row("""{"key":123.45}""")), frame.collect().toSeq) + + // test value is a boolean + frame = sql(s""" + | source = $testTable | eval result = json_object('key', true) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row("""{"key":true}""")), frame.collect().toSeq) + + // test value is an empty array + frame = sql(s""" + | source = $testTable | eval result = json_object('key', json_array()) | head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row("""{"key":[]}""")), frame.collect().toSeq) + + // test value is an array + frame = sql(s""" + | source = $testTable | eval result = json_object('key', json_array(1, 2, 3)) 
| head 1 | fields result + | """.stripMargin) + QueryTest.sameRows(Seq(Row("""{"key":[1,2,3]}""")), frame.collect().toSeq) + + // test value is an another json + frame = sql(s""" + | source = $testTable + | | where isValid = true + | | eval result = json_object('key', json(jString)) | fields result + | """.stripMargin) + val expectedRows = Seq( + Row("""{"key":"{"account_number":1,"balance":39225,"age":32,"gender":"M"}"""), + Row("""{"key":{"f1":"abc", "f2":{"f3":"a", "f4":"b"}}}"""), + Row("""{"key":[1, 2, 3, {"f1":1, "f2":[5, 6]}, 4]}"""), + Row("""{"key":[]}"""), + Row( + """{"key":{"teacher":"Alice", "student":[{"name":"Bob", "rank":1}, {"name":"Charlie", "rank":2}]}}""")) + QueryTest.sameRows(expectedRows, frame.collect().toSeq) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "to_json", + Seq( + UnresolvedFunction( + "named_struct", + Seq( + Literal("key"), + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false)), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_valid()") { + val frame = sql(s""" + | source = $testTable + | | where json_valid(jString) | fields jString + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Seq(validJson1, validJson2, validJson3, validJson4, validJson5).map(Row.apply(_)).toArray + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val frame2 = sql(s""" + | source = $testTable 
+ | | where not json_valid(jString) | fields jString + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + val expectedResults2: Array[Row] = + Seq(invalidJson1, invalidJson2, invalidJson3, invalidJson4, null).map(Row.apply(_)).toArray + assert(results2.sameElements(expectedResults2)) + + val logicalPlan: LogicalPlan = frame2.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = + UnresolvedFunction( + "isnotnull", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false) + val where = Filter(Not(jsonFunc), table) + val expectedPlan = Project(Seq(UnresolvedAttribute("jString")), where) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_keys()") { + val frame = sql(s""" + | source = $testTable + | | where isValid = true + | | eval result = json_keys(json(jString)) | fields result + | """.stripMargin) + val expectedRows = Seq( + Row(Array("account_number", "balance", "age", "gender")), + Row(Array("f1", "f2")), + Row(null), + Row(null), + Row(Array("teacher", "student"))) + QueryTest.sameRows(expectedRows, frame.collect().toSeq) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val jsonFunc = Alias( + UnresolvedFunction( + "json_object_keys", + Seq( + UnresolvedFunction( + "get_json_object", + Seq(UnresolvedAttribute("jString"), Literal("$")), + isDistinct = false)), + isDistinct = false), + "result")() + val eval = Project( + Seq(UnresolvedStar(None), jsonFunc), + Filter(EqualTo(UnresolvedAttribute("isValid"), Literal(true)), table)) + val expectedPlan = Project(Seq(UnresolvedAttribute("result")), eval) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test json_extract()") { + + } +} diff --git 
a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2b916a245..1b99abb3f 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -357,6 +357,24 @@ CAST: 'CAST'; ISEMPTY: 'ISEMPTY'; ISBLANK: 'ISBLANK'; +// JSON TEXT FUNCTIONS +JSON: 'JSON'; +JSON_OBJECT: 'JSON_OBJECT'; +JSON_ARRAY: 'JSON_ARRAY'; +JSON_ARRAY_LENGTH: 'JSON_ARRAY_LENGTH'; +JSON_EXTRACT: 'JSON_EXTRACT'; +JSON_KEYS: 'JSON_KEYS'; +JSON_VALID: 'JSON_VALID'; +//JSON_APPEND: 'JSON_APPEND'; +//JSON_DELETE: 'JSON_DELETE'; +//JSON_EXTEND: 'JSON_EXTEND'; +//JSON_SET: 'JSON_SET'; +//JSON_ARRAY_ALL_MATCH: 'JSON_ALL_MATCH'; +//JSON_ARRAY_ANY_MATCH: 'JSON_ANY_MATCH'; +//JSON_ARRAY_FILTER: 'JSON_FILTER'; +//JSON_ARRAY_MAP: 'JSON_ARRAY_MAP'; +//JSON_ARRAY_REDUCE: 'JSON_ARRAY_REDUCE'; + // BOOL FUNCTIONS LIKE: 'LIKE'; ISNULL: 'ISNULL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 7a6f14839..86f1ffddc 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -498,6 +498,7 @@ evalFunctionName | systemFunctionName | positionFunctionName | coalesceFunctionName + | jsonFunctionName ; functionArgs @@ -745,6 +746,7 @@ conditionFunctionBase | IFNULL | NULLIF | ISPRESENT + | JSON_VALID ; systemFunctionName @@ -773,6 +775,25 @@ textFunctionName | ISBLANK ; +jsonFunctionName + : JSON + | JSON_OBJECT + | JSON_ARRAY + | JSON_ARRAY_LENGTH + | JSON_EXTRACT + | JSON_KEYS + | JSON_VALID +// | JSON_APPEND +// | JSON_DELETE +// | JSON_EXTEND +// | JSON_SET +// | JSON_ARRAY_ALL_MATCH +// | JSON_ARRAY_ANY_MATCH +// | JSON_ARRAY_FILTER +// | JSON_ARRAY_MAP +// | JSON_ARRAY_REDUCE + ; + positionFunctionName : POSITION ; @@ -941,6 +962,7 @@ keywordsCanBeId | intervalUnit | dateTimeFunctionName | 
textFunctionName + | jsonFunctionName | mathematicalFunctionName | positionFunctionName // commands diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 6b549663a..91874052f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -198,6 +198,25 @@ public enum BuiltinFunctionName { TRIM(FunctionName.of("trim")), UPPER(FunctionName.of("upper")), + /** JSON Functions. */ + // If the function argument is a valid JSON, return itself, or return NULL + JSON(FunctionName.of("json")), + JSON_OBJECT(FunctionName.of("json_object")), + JSON_ARRAY(FunctionName.of("json_array")), + JSON_ARRAY_LENGTH(FunctionName.of("json_array_length")), + JSON_EXTRACT(FunctionName.of("json_extract")), + JSON_KEYS(FunctionName.of("json_keys")), + JSON_VALID(FunctionName.of("json_valid")), +// JSON_DELETE(FunctionName.of("json_delete")), +// JSON_APPEND(FunctionName.of("json_append")), +// JSON_EXTEND(FunctionName.of("json_extend")), +// JSON_SET(FunctionName.of("json_set")), +// JSON_ARRAY_ALL_MATCH(FunctionName.of("json_array_all_match")), +// JSON_ARRAY_ANY_MATCH(FunctionName.of("json_array_any_match")), +// JSON_ARRAY_FILTER(FunctionName.of("json_array_filter")), +// JSON_ARRAY_MAP(FunctionName.of("json_array_map")), +// JSON_ARRAY_REDUCE(FunctionName.of("json_array_reduce")), + /** NULL Test. 
*/ IS_NULL(FunctionName.of("is null")), IS_NOT_NULL(FunctionName.of("is not null")), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index d817305a9..3a1340e20 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -7,17 +7,26 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal$; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; import java.util.Map; +import java.util.function.Function; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COALESCE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_ARRAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_EXTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_KEYS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_OBJECT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.JSON_VALID; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; @@ -45,7 +54,7 @@ public interface BuiltinFunctionTranslator { * The name mapping between PPL builtin functions to Spark builtin functions. */ static final Map SPARK_BUILTIN_FUNCTION_NAME_MAPPING - = new ImmutableMap.Builder() + = ImmutableMap.builder() // arithmetic operators .put(ADD, "+") .put(SUBTRACT, "-") @@ -66,15 +75,50 @@ public interface BuiltinFunctionTranslator { .put(ADDDATE, "date_add") // only maps adddate(date, days) .put(DATEDIFF, "datediff") .put(LOCALTIME, "localtimestamp") - //condition functions + // condition functions .put(IS_NULL, "isnull") .put(IS_NOT_NULL, "isnotnull") .put(BuiltinFunctionName.ISPRESENT, "isnotnull") .put(COALESCE, "coalesce") .put(LENGTH, "length") .put(TRIM, "trim") + // json functions + .put(JSON_KEYS, "json_object_keys") + .put(JSON_EXTRACT, "get_json_object") .build(); + /** + * Mapping from PPL builtin functions to lambdas that build composite Spark expressions, for PPL functions with no single Spark builtin equivalent.
+ */ + static final Map, UnresolvedFunction>> PPL_TO_SPARK_FUNC_MAPPING + = ImmutableMap., UnresolvedFunction>>builder() + .put( + JSON_ARRAY, + args -> { + return UnresolvedFunction$.MODULE$.apply("to_json", + seq(UnresolvedFunction$.MODULE$.apply("array", seq(args), false)), false); + }) + .put( + JSON_OBJECT, + args -> { + return UnresolvedFunction$.MODULE$.apply("to_json", + seq(UnresolvedFunction$.MODULE$.apply("named_struct", seq(args), false)), false); + }) + .put( + JSON, + args -> { + return UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false); + }) + .put( + JSON_VALID, + args -> { + return UnresolvedFunction$.MODULE$.apply("isnotnull", + seq(UnresolvedFunction$.MODULE$.apply("get_json_object", + seq(args.get(0), Literal$.MODULE$.apply("$")), false)), false); + }) + .build(); + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { // TODO change it when UDF is supported @@ -82,8 +126,16 @@ static Expression builtinFunction(org.opensearch.sql.ast.expression.Function fun throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); - String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING - .getOrDefault(builtin, builtin.getName().getFunctionName()); + String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.get(builtin); + if (name != null) { + // there is a Spark builtin function mapping with the PPL builtin function + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + Function, UnresolvedFunction> alternative = PPL_TO_SPARK_FUNC_MAPPING.get(builtin); + if (alternative != null) { + return alternative.apply(args); + } + name = builtin.getName().getFunctionName(); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } }