Skip to content

Commit

Permalink
[SPARK-47539][SQL] Make the return value of method castToString be …
Browse files Browse the repository at this point in the history
…`Any => UTF8String`

### What changes were proposed in this pull request?
The pr aims to:
- make the method `castToString(from: DataType): Any => Any` to `castToString(from: DataType): Any => UTF8String` in `ToStringBase`.
- Add UT for `ToPrettyString` to improve the `coverage` of UT.

### Why are the changes needed?
- Let the method return a UTF8String(`Any` -> `UTF8String`), which is more intuitive.
- Currently, `ToPrettyString` lacks the corresponding `UT`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Add new UT.
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45688 from panbingkun/ToPrettyString_improve.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
panbingkun authored and HyukjinKwon committed Mar 25, 2024
1 parent 4a1f241 commit b341787
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None)

override protected def useHexFormatForBinary: Boolean = true

private[this] lazy val castFunc: Any => Any = castToString(child.dataType)
private[this] lazy val castFunc: Any => UTF8String = castToString(child.dataType)

override def eval(input: InternalRow): Any = {
val v = child.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
protected def useHexFormatForBinary: Boolean

// Makes the function accept Any type input by doing `asInstanceOf[T]`.
@inline private def acceptAny[T](func: T => Any): Any => Any = i => func(i.asInstanceOf[T])
@inline private def acceptAny[T](func: T => UTF8String): Any => UTF8String =
i => func(i.asInstanceOf[T])

// Returns a function to convert a value to pretty string. The function assumes input is not null.
protected final def castToString(from: DataType): Any => Any = from match {
protected final def castToString(from: DataType): Any => UTF8String = from match {
case CalendarIntervalType =>
acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString))
case BinaryType if useHexFormatForBinary =>
Expand All @@ -72,7 +73,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
if (array.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(nullString)
} else {
builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
builder.append(toUTF8String(array.get(0, et)))
}
var i = 1
while (i < array.numElements()) {
Expand All @@ -81,7 +82,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
builder.append(toUTF8String(array.get(i, et)))
}
i += 1
}
Expand All @@ -98,25 +99,24 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
val valueArray = map.valueArray()
val keyToUTF8String = castToString(kt)
val valueToUTF8String = castToString(vt)
builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String])
builder.append(keyToUTF8String(keyArray.get(0, kt)))
builder.append(" ->")
if (valueArray.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String])
builder.append(valueToUTF8String(valueArray.get(0, vt)))
}
var i = 1
while (i < map.numElements()) {
builder.append(", ")
builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String])
builder.append(keyToUTF8String(keyArray.get(i, kt)))
builder.append(" ->")
if (valueArray.isNullAt(i)) {
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
builder.append(valueToUTF8String(valueArray.get(i, vt))
.asInstanceOf[UTF8String])
builder.append(valueToUTF8String(valueArray.get(i, vt)))
}
i += 1
}
Expand All @@ -134,7 +134,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
if (row.isNullAt(0)) {
if (nullString.nonEmpty) builder.append(nullString)
} else {
builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String])
builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))))
}
var i = 1
while (i < row.numFields) {
Expand All @@ -143,7 +143,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
if (nullString.nonEmpty) builder.append(" " + nullString)
} else {
builder.append(" ")
builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String])
builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))))
}
i += 1
}
Expand All @@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField)))
case _: DecimalType if useDecimalPlainString =>
acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString))
case _: StringType => identity
case _: StringType => acceptAny[UTF8String](identity[UTF8String])
case _ => o => UTF8String.fromString(o.toString)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

class ToPrettyStringSuite extends SparkFunSuite with ExpressionEvalHelper {

test("CalendarInterval as pretty strings") {
checkEvaluation(
ToPrettyString(Cast(Literal("interval -3 month 1 day 7 hours"), CalendarIntervalType)),
"-3 months 1 days 7 hours")
}

test("Binary as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal("abcdef"), BinaryType)), "[61 62 63 64 65 66]")
}

test("Date as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal("1980-12-17"), DateType, UTC_OPT)), "1980-12-17")
}

test("Timestamp as pretty strings") {
checkEvaluation(
ToPrettyString(Cast(Literal("2012-11-30 09:19:00"), TimestampType, UTC_OPT)),
"2012-11-30 01:19:00")
}

test("TimestampNTZ as pretty strings") {
checkEvaluation(ToPrettyString(Literal(1L, TimestampNTZType)), "1970-01-01 00:00:00.000001")
}

test("Array as pretty strings") {
checkEvaluation(ToPrettyString(Literal.create(Array(1, 2, 3, 4, 5))), "[1, 2, 3, 4, 5]")
}

test("Map as pretty strings") {
checkEvaluation(
ToPrettyString(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c"))),
"{1 -> a, 2 -> b, 3 -> c}")
}

test("Struct as pretty strings") {
checkEvaluation(ToPrettyString(Literal.create((1, "a", 0.1))), "{1, a, 0.1}")
checkEvaluation(
ToPrettyString(Literal.create(Tuple2[String, String](null, null))),
"{NULL, NULL}"
)
}

test("YearMonthInterval as pretty strings") {
checkEvaluation(
ToPrettyString(Cast(Literal("INTERVAL '1-0' YEAR TO MONTH"), YearMonthIntervalType())),
"INTERVAL '1-0' YEAR TO MONTH")
}

test("DayTimeInterval as pretty strings") {
checkEvaluation(
ToPrettyString(Cast(Literal("INTERVAL '1 2:03:04' DAY TO SECOND"), DayTimeIntervalType())),
"INTERVAL '1 02:03:04' DAY TO SECOND")
}

test("Decimal as pretty strings") {
checkEvaluation(
ToPrettyString(Cast(Literal(1234.65), DecimalType(6, 2))), "1234.65")
}

test("String as pretty strings") {
checkEvaluation(ToPrettyString(Literal("s")), "s")
}

test("Char as pretty strings") {
checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a")
}

test("Byte as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal(8), ByteType)), "8")
}

test("Short as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal(8), ShortType)), "8")
}

test("Int as pretty strings") {
checkEvaluation(ToPrettyString(Literal(1)), "1")
}

test("Long as pretty strings") {
checkEvaluation(ToPrettyString(Literal(1L)), "1")
}

test("Float as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal(8), FloatType)), "8.0")
}

test("Double as pretty strings") {
checkEvaluation(ToPrettyString(Cast(Literal(8), DoubleType)), "8.0")
}

test("Boolean as pretty strings") {
checkEvaluation(ToPrettyString(Literal(false)), "false")
checkEvaluation(ToPrettyString(Literal(true)), "true")
}

test("Variant as pretty strings") {
checkEvaluation(
ToPrettyString(Literal(new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5)))),
UTF8String.fromBytes(Array[Byte](1, 2, 3)).toString)
}
}

0 comments on commit b341787

Please sign in to comment.