From e1637e3fbe0a7ee6492cfc909ef13fc1fe0534d1 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 23 Sep 2024 19:51:21 +0800 Subject: [PATCH] [SPARK-48712][SQL][FOLLOWUP] Check whether input is valid utf-8 string or not before entering fast path ### What changes were proposed in this pull request? Check whether input is valid utf-8 string or not before entering fast path ### Why are the changes needed? Avoid behavior change on a corner case where users provide invalid UTF-8 strings for UTF-8 encoding ### Does this PR introduce _any_ user-facing change? no, this is a followup to avoid potential breaking change ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48203 from yaooqinn/SPARK-48712. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../expressions/stringExpressions.scala | 5 ++--- .../expressions/StringExpressionsSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index da6d786efb4e3..786c3968be0fe 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -3039,10 +3039,9 @@ object Encode { legacyCharsets: Boolean, legacyErrorAction: Boolean): Array[Byte] = { val toCharset = charset.toString - if (input.numBytes == 0 || "UTF-8".equalsIgnoreCase(toCharset)) { - return input.getBytes - } + if ("UTF-8".equalsIgnoreCase(toCharset) && input.isValid) return input.getBytes val encoder = CharsetProvider.newEncoder(toCharset, legacyCharsets, legacyErrorAction) + if (input.numBytes == 0) return input.getBytes try { val bb = encoder.encode(CharBuffer.wrap(input.toString)) JavaUtils.bufferToArray(bb) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 29b878230472d..9b454ba764f92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -26,9 +26,12 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.util.CharsetProvider +import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -2076,4 +2079,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) ) } + + test("SPARK-48712: Check whether input is valid utf-8 string or not before entering fast path") { + val str = UTF8String.fromBytes(Array[Byte](-1, -2, -3, -4)) + assert(!str.isValid, "please use a string that is not valid UTF-8 for testing") + val expected = Array[Byte](-17, -65, -67, -17, -65, -67, -17, -65, -67, -17, -65, -67) + val bytes = Encode.encode(str, UTF8String.fromString("UTF-8"), false, false) + assert(bytes === expected) + checkEvaluation(Encode(Literal(str), Literal("UTF-8")), expected) + checkEvaluation(Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-8")), Array.emptyByteArray) + checkErrorInExpression[SparkIllegalArgumentException]( + Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-12345")), + condition = "INVALID_PARAMETER_VALUE.CHARSET", + parameters = Map( + "charset" -> "UTF-12345", + "functionName" -> toSQLId("encode"), + "parameter" -> toSQLId("charset"), + "charsets" -> CharsetProvider.VALID_CHARSETS.mkString(", "))) + } }