diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java new file mode 100644 index 0000000000000..ca2ae80042df7 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.json; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; + +import org.apache.spark.sql.catalyst.expressions.SharedFactory; +import org.apache.spark.sql.catalyst.json.CreateJacksonParser; +import org.apache.spark.unsafe.types.UTF8String; + +public class JsonExpressionUtils { + + public static Integer lengthOfJsonArray(UTF8String json) { + // Counts the top-level elements of a JSON array string; null input yields null. + if (json == null) { + return null; + } + try (JsonParser jsonParser = + CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) { + if (jsonParser.nextToken() == null) { + return null; + } + // Only a top-level JSON array is supported; any other first token yields null. 
+ if (jsonParser.currentToken() != JsonToken.START_ARRAY) { + return null; + } + // Parse the array to compute its length. + int length = 0; + // Each loop iteration consumes one top-level element, until the array ends. + while (jsonParser.nextToken() != JsonToken.END_ARRAY) { + length += 1; + // Skip the children of a nested object or array so they are not counted as elements. + jsonParser.skipChildren(); + } + return length; + } catch (IOException e) { + return null; + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bdcf3f0c1eeab..e1f2b1c1df62a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} @@ -967,54 +969,26 @@ case class SchemaOfJson( group = "json_funcs", since = "3.1.0" ) -case class LengthOfJsonArray(child: Expression) extends UnaryExpression - with CodegenFallback with ExpectsInputTypes { +case class LengthOfJsonArray(child: Expression) + extends UnaryExpression + with ExpectsInputTypes + with RuntimeReplaceable { override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override 
def dataType: DataType = IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" - override def eval(input: InternalRow): Any = { - val json = child.eval(input).asInstanceOf[UTF8String] - // return null for null input - if (json == null) { - return null - } - - try { - Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) { - parser => { - // return null if null array is encountered. - if (parser.nextToken() == null) { - return null - } - // Parse the array to compute its length. - parseCounter(parser, input) - } - } - } catch { - case _: JsonProcessingException | _: IOException => null - } - } - - private def parseCounter(parser: JsonParser, input: InternalRow): Any = { - var length = 0 - // Only JSON array are supported for this function. - if (parser.currentToken != JsonToken.START_ARRAY) { - return null - } - // Keep traversing until the end of JSON array - while(parser.nextToken() != JsonToken.END_ARRAY) { - length += 1 - // skip all the child of inner object or array - parser.skipChildren() - } - length - } - override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray = copy(child = newChild) + + override def replacement: Expression = StaticInvoke( + classOf[JsonExpressionUtils], + dataType, + "lengthOfJsonArray", + Seq(child), + inputTypes + ) } /** diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain index 50ab91560e64a..d70e2eb60aba5 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain @@ -1,2 +1,2 @@ -Project [json_array_length(g#0) AS json_array_length(g)#0] +Project [static_invoke(JsonExpressionUtils.lengthOfJsonArray(g#0)) 
AS json_array_length(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]