Skip to content

Commit

Permalink
[SPARK-49939][SQL] Codegen Support for json_object_keys (by Invoke & …
Browse files Browse the repository at this point in the history
…RuntimeReplaceable)

### What changes were proposed in this pull request?
The pr aims to add `Codegen` Support for `json_object_keys`.

### Why are the changes needed?
- improve codegen coverage.
- simplified code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA & Existed UT (eg: JsonFunctionsSuite#`json_object_keys function`)

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48428 from panbingkun/SPARK-49939.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
panbingkun authored and MaxGekk committed Oct 13, 2024
1 parent 083f44d commit 54fd408
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
package org.apache.spark.sql.catalyst.expressions.json;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;

import org.apache.spark.sql.catalyst.expressions.SharedFactory;
import org.apache.spark.sql.catalyst.json.CreateJacksonParser;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.unsafe.types.UTF8String;

public class JsonExpressionUtils {
Expand Down Expand Up @@ -55,4 +58,32 @@ public static Integer lengthOfJsonArray(UTF8String json) {
return null;
}
}

public static GenericArrayData jsonObjectKeys(UTF8String json) {
// return null for `NULL` input
if (json == null) {
return null;
}
try (JsonParser jsonParser =
CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) {
// return null if an empty string or any other valid JSON string is encountered
if (jsonParser.nextToken() == null || jsonParser.currentToken() != JsonToken.START_OBJECT) {
return null;
}
// Parse the JSON string to get all the keys of outermost JSON object
List<UTF8String> arrayBufferOfKeys = new ArrayList<>();

// traverse until the end of input and ensure it returns valid key
while (jsonParser.nextValue() != null && jsonParser.currentName() != null) {
// add current fieldName to the ArrayBuffer
arrayBufferOfKeys.add(UTF8String.fromString(jsonParser.currentName()));

// skip all the children of inner object or array
jsonParser.skipChildren();
}
return new GenericArrayData(arrayBufferOfKeys.toArray());
} catch (IOException e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import java.io._

import scala.collection.mutable.ArrayBuffer
import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._
Expand Down Expand Up @@ -1014,50 +1013,23 @@ case class LengthOfJsonArray(child: Expression)
group = "json_funcs",
since = "3.1.0"
)
case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback
with ExpectsInputTypes {
case class JsonObjectKeys(child: Expression)
extends UnaryExpression
with ExpectsInputTypes
with RuntimeReplaceable {

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
override def nullable: Boolean = true
override def prettyName: String = "json_object_keys"

override def eval(input: InternalRow): Any = {
val json = child.eval(input).asInstanceOf[UTF8String]
// return null for `NULL` input
if(json == null) {
return null
}

try {
Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) {
parser => {
// return null if an empty string or any other valid JSON string is encountered
if (parser.nextToken() == null || parser.currentToken() != JsonToken.START_OBJECT) {
return null
}
// Parse the JSON string to get all the keys of outermost JSON object
getJsonKeys(parser, input)
}
}
} catch {
case _: JsonProcessingException | _: IOException => null
}
}

private def getJsonKeys(parser: JsonParser, input: InternalRow): GenericArrayData = {
val arrayBufferOfKeys = ArrayBuffer.empty[UTF8String]

// traverse until the end of input and ensure it returns valid key
while(parser.nextValue() != null && parser.currentName() != null) {
// add current fieldName to the ArrayBuffer
arrayBufferOfKeys += UTF8String.fromString(parser.currentName)

// skip all the children of inner object or array
parser.skipChildren()
}
new GenericArrayData(arrayBufferOfKeys.toArray[Any])
}
override def replacement: Expression = StaticInvoke(
classOf[JsonExpressionUtils],
dataType,
"jsonObjectKeys",
Seq(child),
inputTypes
)

override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys =
copy(child = newChild)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [json_object_keys(g#0) AS json_object_keys(g)#0]
Project [static_invoke(JsonExpressionUtils.jsonObjectKeys(g#0)) AS json_object_keys(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

0 comments on commit 54fd408

Please sign in to comment.