diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index da6d786efb4e3..786c3968be0fe 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -3039,10 +3039,9 @@ object Encode {
       legacyCharsets: Boolean,
       legacyErrorAction: Boolean): Array[Byte] = {
     val toCharset = charset.toString
-    if (input.numBytes == 0 || "UTF-8".equalsIgnoreCase(toCharset)) {
-      return input.getBytes
-    }
+    if ("UTF-8".equalsIgnoreCase(toCharset) && input.isValid) return input.getBytes
     val encoder = CharsetProvider.newEncoder(toCharset, legacyCharsets, legacyErrorAction)
+    if (input.numBytes == 0) return input.getBytes
     try {
       val bb = encoder.encode(CharBuffer.wrap(input.toString))
       JavaUtils.bufferToArray(bb)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 29b878230472d..9b454ba764f92 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -26,9 +26,12 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.util.CharsetProvider
+import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -2076,4 +2079,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       )
     )
   }
+
+  test("SPARK-48712: Check whether input is valid utf-8 string or not before entering fast path") {
+    val str = UTF8String.fromBytes(Array[Byte](-1, -2, -3, -4))
+    assert(!str.isValid, "please use a string that is not valid UTF-8 for testing")
+    val expected = Array[Byte](-17, -65, -67, -17, -65, -67, -17, -65, -67, -17, -65, -67)
+    val bytes = Encode.encode(str, UTF8String.fromString("UTF-8"), false, false)
+    assert(bytes === expected)
+    checkEvaluation(Encode(Literal(str), Literal("UTF-8")), expected)
+    checkEvaluation(Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-8")), Array.emptyByteArray)
+    checkErrorInExpression[SparkIllegalArgumentException](
+      Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-12345")),
+      condition = "INVALID_PARAMETER_VALUE.CHARSET",
+      parameters = Map(
+        "charset" -> "UTF-12345",
+        "functionName" -> toSQLId("encode"),
+        "parameter" -> toSQLId("charset"),
+        "charsets" -> CharsetProvider.VALID_CHARSETS.mkString(", ")))
+  }
 }
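
A minimal, JDK-only sketch (not part of the patch; the object name is illustrative) of why the UTF-8 fast path needs the validity check: the slow path decodes the input with replacement and re-encodes it, so each malformed byte becomes the 3-byte encoding of U+FFFD (0xEF 0xBF 0xBD), whereas the old fast path handed the raw invalid bytes back untouched. The sketch assumes the JDK's default replace-on-malformed behavior approximates what UTF8String.toString does for bad input.

import java.nio.charset.StandardCharsets

object EncodeFastPathSketch {
  def main(args: Array[String]): Unit = {
    // Same invalid input as the new test: none of these bytes is a legal UTF-8 lead byte.
    val invalid = Array[Byte](-1, -2, -3, -4)

    // Old fast path: when the target charset is UTF-8, the input bytes were returned as-is.
    val oldFastPath = invalid

    // Slow path: decode with replacement, then re-encode; each malformed byte
    // turns into U+FFFD, i.e. the bytes -17, -65, -67.
    val roundTripped =
      new String(invalid, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8)

    println(oldFastPath.mkString(", "))   // -1, -2, -3, -4
    println(roundTripped.mkString(", "))  // -17, -65, -67 repeated four times
  }
}

The 12 bytes printed for roundTripped match the expected array in the test above, which is why Encode.encode may only skip the encoder when the input is already valid UTF-8.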