[SPARK-49766][SQL] Codegen Support for json_array_length (by `Invoke` & `RuntimeReplaceable`)

### What changes were proposed in this pull request?
This PR adds codegen support for `json_array_length` by making the expression `RuntimeReplaceable`: it is rewritten into a `StaticInvoke` call to a new Java helper, so the `CodegenFallback` evaluation is no longer needed.
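Condensed from the Scala diff below (imports and unchanged members omitted), the expression now looks roughly like this:

```scala
// Sketch only: condensed from the catalyst changes in this commit.
case class LengthOfJsonArray(child: Expression)
  extends UnaryExpression with ExpectsInputTypes with RuntimeReplaceable {

  override def dataType: DataType = IntegerType
  override def prettyName: String = "json_array_length"

  // Both interpreted evaluation and generated code route through a static
  // call to the new Java helper JsonExpressionUtils.lengthOfJsonArray.
  override def replacement: Expression = StaticInvoke(
    classOf[JsonExpressionUtils],
    dataType,
    "lengthOfJsonArray",
    Seq(child),
    inputTypes)
}
```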

### Why are the changes needed?
- Improve codegen coverage.
- Simplify the code.

### Does this PR introduce _any_ user-facing change?
No.
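The observable behavior of `json_array_length` is unchanged. For illustration, a minimal spark-shell example (expected values noted in comments):

```scala
// Counts only the top-level elements of a JSON array; nested values are skipped.
spark.sql("SELECT json_array_length('[1, 2, [3, 4]]')").show()   // 3
// Non-array or malformed JSON input yields NULL.
spark.sql("SELECT json_array_length('{\"key\": 1}')").show()     // NULL
spark.sql("SELECT json_array_length('not json')").show()         // NULL
```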

### How was this patch tested?
Pass GA and the existing UT (e.g. JsonFunctionsSuite#`json_array_length function`).
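For reference, a hypothetical check in the spirit of that test (the actual test body is not part of this diff):

```scala
// Hypothetical sketch only; assumes a QueryTest-style suite with testImplicits in scope.
test("json_array_length function") {
  val df = Seq("[1,2,3,4]", """{"k": 1}""").toDF("json")
  checkAnswer(
    df.selectExpr("json_array_length(json)"),
    Seq(Row(4), Row(null)))
}
```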

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48224 from panbingkun/SPARK-49766.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
panbingkun authored and MaxGekk committed Oct 12, 2024
1 parent 6734d48 commit 1244c5a
Showing 3 changed files with 73 additions and 41 deletions.
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.json;

import java.io.IOException;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;

import org.apache.spark.sql.catalyst.expressions.SharedFactory;
import org.apache.spark.sql.catalyst.json.CreateJacksonParser;
import org.apache.spark.unsafe.types.UTF8String;

public class JsonExpressionUtils {

  public static Integer lengthOfJsonArray(UTF8String json) {
    // Return null for null input.
    if (json == null) {
      return null;
    }
    try (JsonParser jsonParser =
        CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) {
      if (jsonParser.nextToken() == null) {
        return null;
      }
      // Only JSON arrays are supported by this function.
      if (jsonParser.currentToken() != JsonToken.START_ARRAY) {
        return null;
      }
      // Parse the array to compute its length.
      int length = 0;
      // Keep traversing until the end of the JSON array.
      while (jsonParser.nextToken() != JsonToken.END_ARRAY) {
        length += 1;
        // Skip all children of an inner object or array.
        jsonParser.skipChildren();
      }
      return length;
    } catch (IOException e) {
      return null;
    }
  }
}
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern}
@@ -967,54 +969,26 @@
   group = "json_funcs",
   since = "3.1.0"
 )
-case class LengthOfJsonArray(child: Expression) extends UnaryExpression
-  with CodegenFallback with ExpectsInputTypes {
+case class LengthOfJsonArray(child: Expression)
+  extends UnaryExpression
+  with ExpectsInputTypes
+  with RuntimeReplaceable {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
   override def dataType: DataType = IntegerType
   override def nullable: Boolean = true
   override def prettyName: String = "json_array_length"
 
-  override def eval(input: InternalRow): Any = {
-    val json = child.eval(input).asInstanceOf[UTF8String]
-    // return null for null input
-    if (json == null) {
-      return null
-    }
-
-    try {
-      Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) {
-        parser => {
-          // return null if null array is encountered.
-          if (parser.nextToken() == null) {
-            return null
-          }
-          // Parse the array to compute its length.
-          parseCounter(parser, input)
-        }
-      }
-    } catch {
-      case _: JsonProcessingException | _: IOException => null
-    }
-  }
-
-  private def parseCounter(parser: JsonParser, input: InternalRow): Any = {
-    var length = 0
-    // Only JSON array are supported for this function.
-    if (parser.currentToken != JsonToken.START_ARRAY) {
-      return null
-    }
-    // Keep traversing until the end of JSON array
-    while(parser.nextToken() != JsonToken.END_ARRAY) {
-      length += 1
-      // skip all the child of inner object or array
-      parser.skipChildren()
-    }
-    length
-  }
-
   override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray =
     copy(child = newChild)
+
+  override def replacement: Expression = StaticInvoke(
+    classOf[JsonExpressionUtils],
+    dataType,
+    "lengthOfJsonArray",
+    Seq(child),
+    inputTypes
+  )
 }
 
 /**
@@ -1,2 +1,2 @@
-Project [json_array_length(g#0) AS json_array_length(g)#0]
+Project [static_invoke(JsonExpressionUtils.lengthOfJsonArray(g#0)) AS json_array_length(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
