Skip to content

Commit

Permalink
[SPARK-48712][SQL][FOLLOWUP] Check whether input is valid utf-8 string or not before entering fast path
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

Check whether the input is a valid UTF-8 string before entering the fast path.

### Why are the changes needed?

Avoid a behavior change in the corner case where users provide invalid UTF-8 strings for UTF-8 encoding.

### Does this PR introduce _any_ user-facing change?
No, this is a follow-up to avoid a potential breaking change.

### How was this patch tested?
existing tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48203 from yaooqinn/SPARK-48712.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
  • Loading branch information
yaooqinn committed Sep 23, 2024
1 parent 44ec70f commit e1637e3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3039,10 +3039,9 @@ object Encode {
legacyCharsets: Boolean,
legacyErrorAction: Boolean): Array[Byte] = {
val toCharset = charset.toString
if (input.numBytes == 0 || "UTF-8".equalsIgnoreCase(toCharset)) {
return input.getBytes
}
if ("UTF-8".equalsIgnoreCase(toCharset) && input.isValid) return input.getBytes
val encoder = CharsetProvider.newEncoder(toCharset, legacyCharsets, legacyErrorAction)
if (input.numBytes == 0) return input.getBytes
try {
val bb = encoder.encode(CharBuffer.wrap(input.toString))
JavaUtils.bufferToArray(bb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.CharsetProvider
import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -2076,4 +2079,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
)
}

// Regression test for SPARK-48712: encoding to "UTF-8" must not hand back the raw input
// bytes when the input is not a valid UTF-8 string.
test("SPARK-48712: Check whether input is valid utf-8 string or not before entering fast path") {
// Bytes 0xFF 0xFE 0xFD 0xFC; the assertion below guards that this fixture really is invalid.
val str = UTF8String.fromBytes(Array[Byte](-1, -2, -3, -4))
assert(!str.isValid, "please use a string that is not valid UTF-8 for testing")
// Four repetitions of 0xEF 0xBF 0xBD, the UTF-8 encoding of U+FFFD (replacement character):
// one per input byte.
val expected = Array[Byte](-17, -65, -67, -17, -65, -67, -17, -65, -67, -17, -65, -67)
// Direct call; the two `false` arguments are legacyCharsets and legacyErrorAction.
val bytes = Encode.encode(str, UTF8String.fromString("UTF-8"), false, false)
assert(bytes === expected)
// The same result is expected when going through expression evaluation.
checkEvaluation(Encode(Literal(str), Literal("UTF-8")), expected)
// Empty input with a valid charset yields an empty byte array.
checkEvaluation(Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-8")), Array.emptyByteArray)
// Empty input must still be rejected when the charset name is unknown.
checkErrorInExpression[SparkIllegalArgumentException](
Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-12345")),
condition = "INVALID_PARAMETER_VALUE.CHARSET",
parameters = Map(
"charset" -> "UTF-12345",
"functionName" -> toSQLId("encode"),
"parameter" -> toSQLId("charset"),
"charsets" -> CharsetProvider.VALID_CHARSETS.mkString(", ")))
}
}

0 comments on commit e1637e3

Please sign in to comment.