Skip to content

Commit

Permalink
[SPARK-49670][SQL] Enable trim collation for all passthrough expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Enabling usage of passthrough expressions for trim collation.
As with this change there will be more expressions that will support trim collation then those who don't in follow up default for support trim collation will be set on true.

**NOTE: it looks like a tons of changes but only changes are:
for each expression set supportsTrimCollation=true and add tests.**

### Why are the changes needed?
So that all expressions could be used with trim collation

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Add tests to CollationSqlExpressionsSuite

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48739 from jovanpavl-db/implement_passthrough_functions.

Authored-by: Jovan Pavlovic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
jovanpavl-db authored and MaxGekk committed Dec 4, 2024
1 parent 74c3757 commit 10e0b61
Show file tree
Hide file tree
Showing 22 changed files with 634 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ abstract class TypeCoercionHelper {
}

case aj @ ArrayJoin(arr, d, nr)
if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) &&
if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)).
acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
implicitCast(arr, ArrayType(StringType, containsNull)) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ case class CallMethodViaReflection(
"requiredType" -> toSQLType(
TypeCollection(BooleanType, ByteType, ShortType,
IntegerType, LongType, FloatType, DoubleType,
StringTypeWithCollation)),
StringTypeWithCollation(supportsTrimCollation = true))),
"inputSql" -> toSQLExpr(e),
"inputType" -> toSQLType(e.dataType))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {

def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)
if AbstractMapType(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))
.acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ case class Reverse(child: Expression)
override def nullIntolerant: Boolean = true
// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringTypeWithCollation, ArrayType))
Seq(TypeCollection(StringTypeWithCollation(supportsTrimCollation = true), ArrayType))

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -2127,12 +2127,12 @@ case class ArrayJoin(
this(array, delimiter, Some(nullReplacement))

override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation,
StringTypeWithCollation)
Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)),
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))
} else {
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation)
Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)),
StringTypeWithCollation(supportsTrimCollation = true))
}

override def children: Seq[Expression] = if (nullReplacement.isDefined) {
Expand Down Expand Up @@ -2855,7 +2855,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
with QueryErrorsBase {

private def allowedTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, BinaryType, ArrayType)
Seq(StringTypeWithCollation(supportsTrimCollation = true), BinaryType, ArrayType)

final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ case class CsvToStructs(
copy(timeZoneId = Option(timeZoneId))
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil
override def inputTypes: Seq[AbstractDataType] =
StringTypeWithCollation(supportsTrimCollation = true) :: Nil

override def prettyName: String = "from_csv"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
override def dataType: DataType = SQLConf.get.defaultStringType

override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, StringTypeWithCollation)
Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down Expand Up @@ -1279,10 +1279,13 @@ abstract class ToTimestamp
override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(
StringTypeWithCollation, DateType, TimestampType, TimestampNTZType
),
StringTypeWithCollation)
Seq(
TypeCollection(
StringTypeWithCollation(supportsTrimCollation = true),
DateType,
TimestampType,
TimestampNTZType),
StringTypeWithCollation(supportsTrimCollation = true))

override def dataType: DataType = LongType
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true
Expand Down Expand Up @@ -1454,7 +1457,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] =
Seq(LongType, StringTypeWithCollation)
Seq(LongType, StringTypeWithCollation(supportsTrimCollation = true))

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down Expand Up @@ -1566,7 +1569,7 @@ case class NextDay(
def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, StringTypeWithCollation)
Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true))

override def dataType: DataType = DateType
override def nullable: Boolean = true
Expand Down Expand Up @@ -1781,7 +1784,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes {
val funcName: String

override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, StringTypeWithCollation)
Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = TimestampType

override def nullSafeEval(time: Any, timezone: Any): Any = {
Expand Down Expand Up @@ -2123,8 +2126,11 @@ case class ParseToDate(
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
TypeCollection(
StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) +:
format.map(_ => StringTypeWithCollation).toSeq
StringTypeWithCollation(supportsTrimCollation = true),
DateType,
TimestampType,
TimestampNTZType) +:
format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2195,10 +2201,15 @@ case class ParseToTimestamp(
override def inputTypes: Seq[AbstractDataType] = {
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
val types = Seq(StringTypeWithCollation, DateType, TimestampType, TimestampNTZType)
val types = Seq(
StringTypeWithCollation(
supportsTrimCollation = true),
DateType,
TimestampType,
TimestampNTZType)
TypeCollection(
(if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _*
) +: format.map(_ => StringTypeWithCollation).toSeq
) +: format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2329,7 +2340,7 @@ case class TruncDate(date: Expression, format: Expression)
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, StringTypeWithCollation)
Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = DateType
override def prettyName: String = "trunc"
override val instant = date
Expand Down Expand Up @@ -2399,7 +2410,7 @@ case class TruncTimestamp(
override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, TimestampType)
Seq(StringTypeWithCollation(supportsTrimCollation = true), TimestampType)
override def dataType: TimestampType = TimestampType
override def prettyName: String = "date_trunc"
override val instant = timestamp
Expand Down Expand Up @@ -2800,7 +2811,7 @@ case class MakeTimestamp(
// casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0).
override def inputTypes: Seq[AbstractDataType] =
Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++
timezone.map(_ => StringTypeWithCollation)
timezone.map(_ => StringTypeWithCollation(supportsTrimCollation = true))
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
Expand Down Expand Up @@ -3333,7 +3344,10 @@ case class ConvertTimezone(
override def third: Expression = sourceTs

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, StringTypeWithCollation, TimestampNTZType)
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
TimestampNTZType)
override def dataType: DataType = TimestampNTZType

override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ case class GetJsonObject(json: Expression, path: Expression)
override def left: Expression = json
override def right: Expression = path
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, StringTypeWithCollation)
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true
override def prettyName: String = "get_json_object"
Expand Down Expand Up @@ -490,7 +492,8 @@ case class JsonTuple(children: Seq[Expression])
)
} else if (
children.forall(
child => StringTypeWithCollation.acceptsType(child.dataType))) {
child => StringTypeWithCollation(supportsTrimCollation = true)
.acceptsType(child.dataType))) {
TypeCheckResult.TypeCheckSuccess
} else {
DataTypeMismatch(
Expand Down Expand Up @@ -709,7 +712,8 @@ case class JsonToStructs(
|""".stripMargin)
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil
override def inputTypes: Seq[AbstractDataType] =
StringTypeWithCollation(supportsTrimCollation = true) :: Nil

override def sql: String = schema match {
case _: MapType => "entries"
Expand Down Expand Up @@ -922,7 +926,8 @@ case class LengthOfJsonArray(child: Expression)
with ExpectsInputTypes
with RuntimeReplaceable {

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = IntegerType
override def nullable: Boolean = true
override def prettyName: String = "json_array_length"
Expand Down Expand Up @@ -967,7 +972,8 @@ case class JsonObjectKeys(child: Expression)
with ExpectsInputTypes
with RuntimeReplaceable {

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
override def nullable: Boolean = true
override def prettyName: String = "json_object_keys"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,11 @@ case class Mask(
*/
override def inputTypes: Seq[AbstractDataType] =
Seq(
StringTypeWithCollation,
StringTypeWithCollation,
StringTypeWithCollation,
StringTypeWithCollation,
StringTypeWithCollation)
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))

override def nullable: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ case class Conv(
override def second: Expression = fromBaseExpr
override def third: Expression = toBaseExpr
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, IntegerType, IntegerType)
Seq(StringTypeWithCollation(supportsTrimCollation = true), IntegerType, IntegerType)
override def dataType: DataType = first.dataType
override def nullable: Boolean = true

Expand Down Expand Up @@ -1118,7 +1118,7 @@ case class Hex(child: Expression)
override def nullIntolerant: Boolean = true

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation))
Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation(supportsTrimCollation = true)))

override def dataType: DataType = child.dataType match {
case st: StringType => st
Expand Down Expand Up @@ -1163,7 +1163,8 @@ case class Unhex(child: Expression, failOnError: Boolean = false)

def this(expr: Expression) = this(expr, false)

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))

override def nullable: Boolean = true
override def dataType: DataType = BinaryType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
override def foldable: Boolean = false
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation))
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
AbstractMapType(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true)
))

override def left: Expression = errorClass
override def right: Expression = errorParms
Expand Down Expand Up @@ -416,8 +421,8 @@ case class AesEncrypt(

override def inputTypes: Seq[AbstractDataType] =
Seq(BinaryType, BinaryType,
StringTypeWithCollation,
StringTypeWithCollation,
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
BinaryType, BinaryType)

override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad)
Expand Down Expand Up @@ -493,8 +498,8 @@ case class AesDecrypt(
override def inputTypes: Seq[AbstractDataType] = {
Seq(BinaryType,
BinaryType,
StringTypeWithCollation,
StringTypeWithCollation, BinaryType)
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true), BinaryType)
}

override def prettyName: String = "aes_decrypt"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo
}

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation, StringTypeWithCollation)
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))

override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
Expand Down Expand Up @@ -288,7 +290,7 @@ case class ToCharacter(left: Expression, right: Expression)

override def dataType: DataType = SQLConf.get.defaultStringType
override def inputTypes: Seq[AbstractDataType] =
Seq(DecimalType, StringTypeWithCollation)
Seq(DecimalType, StringTypeWithCollation(supportsTrimCollation = true))
override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3557,9 +3557,9 @@ case class Sentences(
ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false)
override def inputTypes: Seq[AbstractDataType] =
Seq(
StringTypeWithCollation,
StringTypeWithCollation,
StringTypeWithCollation
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true)
)
override def first: Expression = str
override def second: Expression = language
Expand Down
Loading

0 comments on commit 10e0b61

Please sign in to comment.