From 10e0b619092b9683f08846409b4083dcd7624478 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Wed, 4 Dec 2024 15:06:26 +0100 Subject: [PATCH] [SPARK-49670][SQL] Enable trim collation for all passthrough expressions ### What changes were proposed in this pull request? Enabling usage of passthrough expressions for trim collation. With this change there will be more expressions that support trim collation than those that don't; in a follow-up, the default for supporting trim collation will be set to true. **NOTE: it looks like a ton of changes, but the only changes are: for each expression set supportsTrimCollation=true and add tests.** ### Why are the changes needed? So that all expressions can be used with trim collation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests to CollationSQLExpressionsSuite. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48739 from jovanpavl-db/implement_passthrough_functions. Authored-by: Jovan Pavlovic Signed-off-by: Max Gekk --- .../analysis/TypeCoercionHelper.scala | 3 +- .../expressions/CallMethodViaReflection.scala | 2 +- .../sql/catalyst/expressions/ExprUtils.scala | 4 +- .../expressions/collectionOperations.scala | 14 +- .../catalyst/expressions/csvExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 46 +- .../expressions/jsonExpressions.scala | 16 +- .../expressions/maskExpressions.scala | 10 +- .../expressions/mathExpressions.scala | 7 +- .../spark/sql/catalyst/expressions/misc.scala | 15 +- .../expressions/numberFormatExpressions.scala | 6 +- .../expressions/stringExpressions.scala | 6 +- .../catalyst/expressions/urlExpressions.scala | 12 +- .../variant/variantExpressions.scala | 5 +- .../sql/catalyst/expressions/xml/xpath.scala | 6 +- .../catalyst/expressions/xmlExpressions.scala | 3 +- .../analysis/AnsiTypeCoercionSuite.scala | 16 +- .../analyzer-results/collations.sql.out | 40 ++ .../resources/sql-tests/inputs/collations.sql | 5 
+ .../sql-tests/results/collations.sql.out | 110 +++++ .../sql/CollationSQLExpressionsSuite.scala | 400 ++++++++++++++++-- .../org/apache/spark/sql/CollationSuite.scala | 15 + 22 files changed, 634 insertions(+), 110 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala index 3fc4b71c986ff..ab2ab50cb33ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala @@ -318,7 +318,8 @@ abstract class TypeCoercionHelper { } case aj @ ArrayJoin(arr, d, nr) - if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) && + if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)). + acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull implicitCast(arr, ArrayType(StringType, containsNull)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index d38ee01485288..4eb14fb9e7b86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -115,7 +115,7 @@ case class CallMethodViaReflection( "requiredType" -> toSQLType( TypeCollection(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - StringTypeWithCollation)), + StringTypeWithCollation(supportsTrimCollation = true))), "inputSql" -> toSQLExpr(e), "inputType" -> toSQLType(e.dataType)) ) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index e65a0200b064f..8b7d641828ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -61,7 +61,9 @@ object ExprUtils extends EvalHelper with QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation) + if AbstractMapType( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) .acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fb130574d3474..9843e844ad169 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1354,7 +1354,7 @@ case class Reverse(child: Expression) override def nullIntolerant: Boolean = true // Input types are utilized by type coercion in ImplicitTypeCasts. 
override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCollation, ArrayType)) + Seq(TypeCollection(StringTypeWithCollation(supportsTrimCollation = true), ArrayType)) override def dataType: DataType = child.dataType @@ -2127,12 +2127,12 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(AbstractArrayType(StringTypeWithCollation), - StringTypeWithCollation, - StringTypeWithCollation) + Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)), + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) } else { - Seq(AbstractArrayType(StringTypeWithCollation), - StringTypeWithCollation) + Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)), + StringTypeWithCollation(supportsTrimCollation = true)) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2855,7 +2855,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio with QueryErrorsBase { private def allowedTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, BinaryType, ArrayType) + Seq(StringTypeWithCollation(supportsTrimCollation = true), BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 21b6295a59f02..04fb9bc133c67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -87,7 +87,8 @@ case class CsvToStructs( copy(timeZoneId = Option(timeZoneId)) } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil + 
override def inputTypes: Seq[AbstractDataType] = + StringTypeWithCollation(supportsTrimCollation = true) :: Nil override def prettyName: String = "from_csv" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index fba3927a0bc9c..55e6c7f1503fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -971,7 +971,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = - Seq(TimestampType, StringTypeWithCollation) + Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true)) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1279,10 +1279,13 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection( - StringTypeWithCollation, DateType, TimestampType, TimestampNTZType - ), - StringTypeWithCollation) + Seq( + TypeCollection( + StringTypeWithCollation(supportsTrimCollation = true), + DateType, + TimestampType, + TimestampNTZType), + StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1454,7 +1457,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(LongType, StringTypeWithCollation) + Seq(LongType, StringTypeWithCollation(supportsTrimCollation = true)) override def 
withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1566,7 +1569,7 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) override def inputTypes: Seq[AbstractDataType] = - Seq(DateType, StringTypeWithCollation) + Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1781,7 +1784,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes { val funcName: String override def inputTypes: Seq[AbstractDataType] = - Seq(TimestampType, StringTypeWithCollation) + Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2123,8 +2126,11 @@ case class ParseToDate( // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. TypeCollection( - StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringTypeWithCollation).toSeq + StringTypeWithCollation(supportsTrimCollation = true), + DateType, + TimestampType, + TimestampNTZType) +: + format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq } override protected def withNewChildrenInternal( @@ -2195,10 +2201,15 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. 
- val types = Seq(StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) + val types = Seq( + StringTypeWithCollation( + supportsTrimCollation = true), + DateType, + TimestampType, + TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringTypeWithCollation).toSeq + ) +: format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq } override protected def withNewChildrenInternal( @@ -2329,7 +2340,7 @@ case class TruncDate(date: Expression, format: Expression) override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = - Seq(DateType, StringTypeWithCollation) + Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2399,7 +2410,7 @@ case class TruncTimestamp( override def right: Expression = timestamp override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, TimestampType) + Seq(StringTypeWithCollation(supportsTrimCollation = true), TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2800,7 +2811,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). 
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringTypeWithCollation) + timezone.map(_ => StringTypeWithCollation(supportsTrimCollation = true)) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -3333,7 +3344,10 @@ case class ConvertTimezone( override def third: Expression = sourceTs override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation, TimestampNTZType) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 68cce1c2a138b..affc8261dc883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -133,7 +133,9 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -490,7 +492,8 @@ case class JsonTuple(children: Seq[Expression]) ) } else if ( children.forall( - child => 
StringTypeWithCollation.acceptsType(child.dataType))) { + child => StringTypeWithCollation(supportsTrimCollation = true) + .acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -709,7 +712,8 @@ case class JsonToStructs( |""".stripMargin) } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = + StringTypeWithCollation(supportsTrimCollation = true) :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -922,7 +926,8 @@ case class LengthOfJsonArray(child: Expression) with ExpectsInputTypes with RuntimeReplaceable { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -967,7 +972,8 @@ case class JsonObjectKeys(child: Expression) with ExpectsInputTypes with RuntimeReplaceable { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index 7be6df14194fc..5b17d2029ed1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -193,11 +193,11 @@ case class Mask( */ override def inputTypes: Seq[AbstractDataType] = 
Seq( - StringTypeWithCollation, - StringTypeWithCollation, - StringTypeWithCollation, - StringTypeWithCollation, - StringTypeWithCollation) + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 30f07dcc1e67e..317a08b8c64c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -455,7 +455,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, IntegerType, IntegerType) + Seq(StringTypeWithCollation(supportsTrimCollation = true), IntegerType, IntegerType) override def dataType: DataType = first.dataType override def nullable: Boolean = true @@ -1118,7 +1118,7 @@ case class Hex(child: Expression) override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation)) + Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation(supportsTrimCollation = true))) override def dataType: DataType = child.dataType match { case st: StringType => st @@ -1163,7 +1163,8 @@ case class Unhex(child: Expression, failOnError: Boolean = false) def this(expr: Expression) = this(expr, false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + 
Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 622a0e0aa5bb7..fb30eab327d4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -85,7 +85,12 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + AbstractMapType( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true) + )) override def left: Expression = errorClass override def right: Expression = errorParms @@ -416,8 +421,8 @@ case class AesEncrypt( override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, - StringTypeWithCollation, - StringTypeWithCollation, + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -493,8 +498,8 @@ case class AesDecrypt( override def inputTypes: Seq[AbstractDataType] = { Seq(BinaryType, BinaryType, - StringTypeWithCollation, - StringTypeWithCollation, BinaryType) + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), BinaryType) } override def prettyName: String = "aes_decrypt" diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index d4dcfdc5e72fb..fd6399d65271e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -51,7 +51,9 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -288,7 +290,7 @@ case class ToCharacter(left: Expression, right: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = - Seq(DecimalType, StringTypeWithCollation) + Seq(DecimalType, StringTypeWithCollation(supportsTrimCollation = true)) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2ea53350fea36..efd7e5c07de40 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -3557,9 +3557,9 @@ case class Sentences( ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq( - StringTypeWithCollation, - 
StringTypeWithCollation, - StringTypeWithCollation + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true) ) override def first: Expression = str override def second: Expression = language diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 22dcd33937dfb..845ca0b608ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -57,13 +57,14 @@ case class UrlEncode(child: Expression) SQLConf.get.defaultStringType, "encode", Seq(child), - Seq(StringTypeWithCollation)) + Seq(StringTypeWithCollation(supportsTrimCollation = true))) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def prettyName: String = "url_encode" } @@ -96,13 +97,14 @@ case class UrlDecode(child: Expression, failOnError: Boolean = true) SQLConf.get.defaultStringType, "decode", Seq(child, Literal(failOnError)), - Seq(StringTypeWithCollation, BooleanType)) + Seq(StringTypeWithCollation(supportsTrimCollation = true), BooleanType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def prettyName: String = "url_decode" } @@ -211,7 +213,7 @@ case class ParseUrl( override def nullable: Boolean = 
true override def inputTypes: Seq[AbstractDataType] = - Seq.fill(children.size)(StringTypeWithCollation) + Seq.fill(children.size)(StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 06aec93912984..1639a161df4cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -66,7 +66,8 @@ case class ParseJson(child: Expression, failOnError: Boolean = true) inputTypes :+ BooleanType :+ BooleanType, returnNullable = !failOnError) - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = + StringTypeWithCollation(supportsTrimCollation = true) :: Nil override def dataType: DataType = VariantType @@ -270,7 +271,7 @@ case class VariantGet( final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET) override def inputTypes: Seq[AbstractDataType] = - Seq(VariantType, StringTypeWithCollation) + Seq(VariantType, StringTypeWithCollation(supportsTrimCollation = true)) override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 2c18ffa2abecb..2e591288a21cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -41,7 +41,9 @@ abstract 
class XPathExtract override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, StringTypeWithCollation) + Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true)) override def checkInputDataTypes(): TypeCheckResult = { if (!path.foldable) { @@ -49,7 +51,7 @@ abstract class XPathExtract errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("path"), - "inputType" -> toSQLType(StringTypeWithCollation), + "inputType" -> toSQLType(StringTypeWithCollation(supportsTrimCollation = true)), "inputExpr" -> toSQLExpr(path) ) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 66f7f25e4abe8..d8254f04b4d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -126,7 +126,8 @@ case class XmlToStructs( defineCodeGen(ctx, ev, input => s"(InternalRow) $expr.nullSafeEval($input)") } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = + StringTypeWithCollation(supportsTrimCollation = true) :: Nil override def prettyName: String = "from_xml" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index 8cf7d78b510be..139e89828f8e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -1057,11 +1057,11 @@ class AnsiTypeCoercionSuite extends 
TypeCoercionSuiteBase { ArrayType(IntegerType)) shouldCast( ArrayType(StringType), - AbstractArrayType(StringTypeWithCollation), + AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)), ArrayType(StringType)) shouldCast( ArrayType(IntegerType), - AbstractArrayType(StringTypeWithCollation), + AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)), ArrayType(StringType)) shouldCast( ArrayType(StringType), @@ -1075,11 +1075,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ArrayType(ArrayType(IntegerType))) shouldCast( ArrayType(ArrayType(StringType)), - AbstractArrayType(AbstractArrayType(StringTypeWithCollation)), + AbstractArrayType(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true))), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(IntegerType)), - AbstractArrayType(AbstractArrayType(StringTypeWithCollation)), + AbstractArrayType(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true))), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(StringType)), @@ -1088,16 +1088,16 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { // Invalid casts involving casting arrays into non-complex types. shouldNotCast(ArrayType(IntegerType), IntegerType) - shouldNotCast(ArrayType(StringType), StringTypeWithCollation) + shouldNotCast(ArrayType(StringType), StringTypeWithCollation(supportsTrimCollation = true)) shouldNotCast(ArrayType(StringType), IntegerType) - shouldNotCast(ArrayType(IntegerType), StringTypeWithCollation) + shouldNotCast(ArrayType(IntegerType), StringTypeWithCollation(supportsTrimCollation = true)) // Invalid casts involving casting arrays of arrays into arrays of non-complex types. 
shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(IntegerType)) shouldNotCast(ArrayType(ArrayType(StringType)), - AbstractArrayType(StringTypeWithCollation)) + AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true))) shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(IntegerType)) shouldNotCast(ArrayType(ArrayType(IntegerType)), - AbstractArrayType(StringTypeWithCollation)) + AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true))) } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 0d5c414416d40..7a4777c34fed6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -2143,6 +2143,14 @@ Project [octet_length(collate(utf8_binary#x, utf8_lcase)) AS octet_length(collat +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select octet_length(utf8_binary collate utf8_lcase_rtrim), octet_length(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [octet_length(collate(utf8_binary#x, utf8_lcase_rtrim)) AS octet_length(collate(utf8_binary, utf8_lcase_rtrim))#x, octet_length(collate(utf8_lcase#x, utf8_binary_rtrim)) AS octet_length(collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select luhn_check(num) from t9 -- !query analysis @@ -2233,6 +2241,14 @@ Project [is_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS is_valid_utf8(coll +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select is_valid_utf8(utf8_binary collate utf8_lcase_rtrim), is_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project 
[is_valid_utf8(collate(utf8_binary#x, utf8_lcase_rtrim)) AS is_valid_utf8(collate(utf8_binary, utf8_lcase_rtrim))#x, is_valid_utf8(collate(utf8_lcase#x, utf8_binary_rtrim)) AS is_valid_utf8(collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5 -- !query analysis @@ -2249,6 +2265,14 @@ Project [make_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS make_valid_utf8( +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select make_valid_utf8(utf8_binary collate utf8_lcase_rtrim), make_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [make_valid_utf8(collate(utf8_binary#x, utf8_lcase_rtrim)) AS make_valid_utf8(collate(utf8_binary, utf8_lcase_rtrim))#x, make_valid_utf8(collate(utf8_lcase#x, utf8_binary_rtrim)) AS make_valid_utf8(collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5 -- !query analysis @@ -2265,6 +2289,14 @@ Project [validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS validate_utf8(coll +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select validate_utf8(utf8_binary collate utf8_lcase_rtrim), validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [validate_utf8(collate(utf8_binary#x, utf8_lcase_rtrim)) AS validate_utf8(collate(utf8_binary, utf8_lcase_rtrim))#x, validate_utf8(collate(utf8_lcase#x, utf8_binary_rtrim)) AS validate_utf8(collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select 
try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5 -- !query analysis @@ -2281,6 +2313,14 @@ Project [try_validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS try_validate_u +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select try_validate_utf8(utf8_binary collate utf8_lcase_rtrim), try_validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query analysis +Project [try_validate_utf8(collate(utf8_binary#x, utf8_lcase_rtrim)) AS try_validate_utf8(collate(utf8_binary, utf8_lcase_rtrim))#x, try_validate_utf8(collate(utf8_lcase#x, utf8_binary_rtrim)) AS try_validate_utf8(collate(utf8_lcase, utf8_binary_rtrim))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + -- !query select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index b4d33bb0196c9..df15adf2f8fe4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -328,6 +328,7 @@ select bit_length(utf8_binary), bit_length(utf8_lcase) from t5; select bit_length(utf8_binary collate utf8_lcase), bit_length(utf8_lcase collate utf8_binary) from t5; select octet_length(utf8_binary), octet_length(utf8_lcase) from t5; select octet_length(utf8_binary collate utf8_lcase), octet_length(utf8_lcase collate utf8_binary) from t5; +select octet_length(utf8_binary collate utf8_lcase_rtrim), octet_length(utf8_lcase collate utf8_binary_rtrim) from t5; -- Luhncheck select luhn_check(num) from t9; @@ -344,18 +345,22 @@ select levenshtein(utf8_binary, 'AaAA' collate utf8_lcase, 3), levenshtein(utf8_ -- IsValidUTF8 select is_valid_utf8(utf8_binary), is_valid_utf8(utf8_lcase) from t5; select is_valid_utf8(utf8_binary collate utf8_lcase), 
is_valid_utf8(utf8_lcase collate utf8_binary) from t5; +select is_valid_utf8(utf8_binary collate utf8_lcase_rtrim), is_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5; -- MakeValidUTF8 select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5; select make_valid_utf8(utf8_binary collate utf8_lcase), make_valid_utf8(utf8_lcase collate utf8_binary) from t5; +select make_valid_utf8(utf8_binary collate utf8_lcase_rtrim), make_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5; -- ValidateUTF8 select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5; select validate_utf8(utf8_binary collate utf8_lcase), validate_utf8(utf8_lcase collate utf8_binary) from t5; +select validate_utf8(utf8_binary collate utf8_lcase_rtrim), validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5; -- TryValidateUTF8 select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5; select try_validate_utf8(utf8_binary collate utf8_lcase), try_validate_utf8(utf8_lcase collate utf8_binary) from t5; +select try_validate_utf8(utf8_binary collate utf8_lcase_rtrim), try_validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5; -- Left/Right/Substr select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5; diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index e96549f00d6ec..fbfde3d78c1be 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -3583,6 +3583,28 @@ struct +-- !query output +23 23 +29 29 +3 3 +3 3 +3 4 +3 4 +4 3 +4 4 +5 3 +6 7 +7 7 +8 1 +8 24 +8 8 +8 8 + + -- !query select luhn_check(num) from t9 -- !query schema @@ -3776,6 +3798,28 @@ true true true true +-- !query +select is_valid_utf8(utf8_binary collate utf8_lcase_rtrim), is_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query schema +struct +-- !query output 
+true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true + + -- !query select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5 -- !query schema @@ -3820,6 +3864,28 @@ kitten sitTing İo İo +-- !query +select make_valid_utf8(utf8_binary collate utf8_lcase_rtrim), make_valid_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo +İo İo +İo İo + + -- !query select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5 -- !query schema @@ -3864,6 +3930,28 @@ kitten sitTing İo İo +-- !query +select validate_utf8(utf8_binary collate utf8_lcase_rtrim), validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo +İo İo +İo İo + + -- !query select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5 -- !query schema @@ -3908,6 +3996,28 @@ kitten sitTing İo İo +-- !query +select try_validate_utf8(utf8_binary collate utf8_lcase_rtrim), try_validate_utf8(utf8_lcase collate utf8_binary_rtrim) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo +İo İo +İo İo + + -- !query select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 4e91fd721a075..384411a0fd342 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -39,8 +39,18 @@ class CollationSQLExpressionsSuite with SharedSparkSession with ExpressionEvalHelper { - private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI") - private val testAdditionalCollations = Seq("UNICODE", "SR", "SR_CI", "SR_AI", "SR_CI_AI") + private val testSuppCollations = + Seq( + "UTF8_BINARY", + "UTF8_BINARY_RTRIM", + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "UNICODE", + "UNICODE_RTRIM", + "UNICODE_CI", + "UNICODE_CI_RTRIM") + private val testAdditionalCollations = Seq("UNICODE", + "SR", "SR_RTRIM", "SR_CI", "SR_AI", "SR_CI_AI") private val fullyQualifiedPrefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}." 
test("Support Md5 hash expression with collation") { @@ -264,11 +274,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( UrlEncodeTestCase("https://spark.apache.org", "UTF8_BINARY", "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UTF8_BINARY_RTRIM", + "https%3A%2F%2Fspark.apache.org"), UrlEncodeTestCase("https://spark.apache.org", "UTF8_LCASE", "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UTF8_LCASE_RTRIM", + "https%3A%2F%2Fspark.apache.org"), UrlEncodeTestCase("https://spark.apache.org", "UNICODE", "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UNICODE_RTRIM", + "https%3A%2F%2Fspark.apache.org"), UrlEncodeTestCase("https://spark.apache.org", "UNICODE_CI", + "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UNICODE_CI_RTRIM", "https%3A%2F%2Fspark.apache.org") ) @@ -298,11 +316,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_BINARY", "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_BINARY_RTRIM", + "https://spark.apache.org"), UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_LCASE", "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_LCASE_RTRIM", + "https://spark.apache.org"), UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE", "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE_RTRIM", + "https://spark.apache.org"), UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE_CI", + "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE_CI_RTRIM", "https://spark.apache.org") ) @@ -333,11 +359,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( ParseUrlTestCase("http://spark.apache.org/path?query=1", "UTF8_BINARY", "HOST", "spark.apache.org"), + 
ParseUrlTestCase("http://spark.apache.org/path?query=1", "UTF8_BINARY_RTRIM", "HOST", + "spark.apache.org"), ParseUrlTestCase("http://spark.apache.org/path?query=2", "UTF8_LCASE", "PATH", "/path"), + ParseUrlTestCase("http://spark.apache.org/path?query=2", "UTF8_LCASE_RTRIM", "PATH", + "/path"), ParseUrlTestCase("http://spark.apache.org/path?query=3", "UNICODE", "QUERY", "query=3"), + ParseUrlTestCase("http://spark.apache.org/path?query=3", "UNICODE_RTRIM", "QUERY", + "query=3"), ParseUrlTestCase("http://spark.apache.org/path?query=4", "UNICODE_CI", "PROTOCOL", + "http"), + ParseUrlTestCase("http://spark.apache.org/path?query=4", "UNICODE_CI_RTRIM", "PROTOCOL", "http") ) @@ -372,11 +406,20 @@ class CollationSQLExpressionsSuite Row(1), Seq( StructField("a", IntegerType, nullable = true) )), + CsvToStructsTestCase("1", "UTF8_BINARY_RTRIM", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), CsvToStructsTestCase("true, 0.8", "UTF8_LCASE", "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( StructField("A", BooleanType, nullable = true), StructField("B", DoubleType, nullable = true) )), + CsvToStructsTestCase("true, 0.8", "UTF8_LCASE_RTRIM", "'A BOOLEAN, B DOUBLE'", "", + Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "", Row("Spark"), Seq( StructField("a", StringType, nullable = true) @@ -385,6 +428,10 @@ class CollationSQLExpressionsSuite Row("Spark"), Seq( StructField("a", StringType("UNICODE"), nullable = true) )), + CsvToStructsTestCase("\"Spark\"", "UNICODE_RTRIM", "'a STRING COLLATE UNICODE_RTRIM'", "", + Row("Spark"), Seq( + StructField("a", StringType("UNICODE_RTRIM"), nullable = true) + )), CsvToStructsTestCase("26/08/2015", "UTF8_BINARY", "'time Timestamp'", ", map('timestampFormat', 'dd/MM/yyyy')", Row( new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") @@ -419,10 
+466,16 @@ class CollationSQLExpressionsSuite val testCases = Seq( SchemaOfCsvTestCase("1", "UTF8_BINARY", "STRUCT<_c0: INT>"), + SchemaOfCsvTestCase("1", "UTF8_BINARY_RTRIM", "STRUCT<_c0: INT>"), SchemaOfCsvTestCase("true,0.8", "UTF8_LCASE", "STRUCT<_c0: BOOLEAN, _c1: DOUBLE>"), + SchemaOfCsvTestCase("true,0.8", "UTF8_LCASE_RTRIM", + "STRUCT<_c0: BOOLEAN, _c1: DOUBLE>"), SchemaOfCsvTestCase("2015-08-26", "UNICODE", "STRUCT<_c0: DATE>"), + SchemaOfCsvTestCase("2015-08-26", "UNICODE_RTRIM", "STRUCT<_c0: DATE>"), SchemaOfCsvTestCase("abc", "UNICODE_CI", + "STRUCT<_c0: STRING>"), + SchemaOfCsvTestCase("abc", "UNICODE_CI_RTRIM", "STRUCT<_c0: STRING>") ) @@ -451,9 +504,14 @@ class CollationSQLExpressionsSuite val testCases = Seq( StructsToCsvTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY", "1,2"), + StructsToCsvTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY_RTRIM", "1,2"), StructsToCsvTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_LCASE", "true,2.0"), + StructsToCsvTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_LCASE_RTRIM", "true,2.0"), StructsToCsvTestCase("named_struct()", "UNICODE", null), + StructsToCsvTestCase("named_struct()", "UNICODE_RTRIM", null), StructsToCsvTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI", + "2015-08-26T00:00:00.000-07:00"), + StructsToCsvTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI_RTRIM", "2015-08-26T00:00:00.000-07:00") ) @@ -484,9 +542,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( ConvTestCase("100", "2", "10", "UTF8_BINARY", "4"), + ConvTestCase("100", "2", "10", "UTF8_BINARY_RTRIM", "4"), ConvTestCase("100", "2", "10", "UTF8_LCASE", "4"), + ConvTestCase("100", "2", "10", "UTF8_LCASE_RTRIM", "4"), ConvTestCase("100", "2", "10", "UNICODE", "4"), - ConvTestCase("100", "2", "10", "UNICODE_CI", "4") + ConvTestCase("100", "2", "10", "UNICODE_RTRIM", "4"), + ConvTestCase("100", "2", "10", "UNICODE_CI", "4"), + ConvTestCase("100", "2", "10", 
"UNICODE_CI_RTRIM", "4") ) testCases.foreach(t => { val query = @@ -508,9 +570,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( BinTestCase("13", "UTF8_BINARY", "1101"), + BinTestCase("13", "UTF8_BINARY_RTRIM", "1101"), BinTestCase("13", "UTF8_LCASE", "1101"), + BinTestCase("13", "UTF8_LCASE_RTRIM", "1101"), BinTestCase("13", "UNICODE", "1101"), - BinTestCase("13", "UNICODE_CI", "1101") + BinTestCase("13", "UNICODE_RTRIM", "1101"), + BinTestCase("13", "UNICODE_CI", "1101"), + BinTestCase("13", "UNICODE_CI_RTRIM", "1101") ) testCases.foreach(t => { val query = @@ -533,9 +599,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( HexTestCase("13", "UTF8_BINARY", "D"), + HexTestCase("13", "UTF8_BINARY_RTRIM", "D"), HexTestCase("13", "UTF8_LCASE", "D"), + HexTestCase("13", "UTF8_LCASE_RTRIM", "D"), HexTestCase("13", "UNICODE", "D"), - HexTestCase("13", "UNICODE_CI", "D") + HexTestCase("13", "UNICODE_RTRIM", "D"), + HexTestCase("13", "UNICODE_CI", "D"), + HexTestCase("13", "UNICODE_CI_RTRIM", "D") ) testCases.foreach(t => { val query = @@ -558,10 +628,15 @@ class CollationSQLExpressionsSuite val testCases = Seq( HexTestCase("Spark SQL", "UTF8_BINARY", "537061726B2053514C"), + HexTestCase("Spark SQL", "UTF8_BINARY_RTRIM", "537061726B2053514C"), HexTestCase("Spark SQL", "UTF8_LCASE", "537061726B2053514C"), + HexTestCase("Spark SQL", "UTF8_LCASE_RTRIM", "537061726B2053514C"), HexTestCase("Spark SQL", "UNICODE", "537061726B2053514C"), + HexTestCase("Spark SQL", "UNICODE_RTRIM", "537061726B2053514C"), HexTestCase("Spark SQL", "UNICODE_CI", "537061726B2053514C"), - HexTestCase("Spark SQL", "DE_CI_AI", "537061726B2053514C") + HexTestCase("Spark SQL", "UNICODE_CI_RTRIM", "537061726B2053514C"), + HexTestCase("Spark SQL", "DE_CI_AI", "537061726B2053514C"), + HexTestCase("Spark SQL", "DE_CI_AI_RTRIM", "537061726B2053514C") ) testCases.foreach(t => { val query = @@ -582,9 +657,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( 
UnHexTestCase("537061726B2053514C", "UTF8_BINARY", "Spark SQL"), + UnHexTestCase("537061726B2053514C", "UTF8_BINARY_RTRIM", "Spark SQL"), UnHexTestCase("537061726B2053514C", "UTF8_LCASE", "Spark SQL"), + UnHexTestCase("537061726B2053514C", "UTF8_LCASE_RTRIM", "Spark SQL"), UnHexTestCase("537061726B2053514C", "UNICODE", "Spark SQL"), + UnHexTestCase("537061726B2053514C", "UNICODE_RTRIM", "Spark SQL"), UnHexTestCase("537061726B2053514C", "UNICODE_CI", "Spark SQL"), + UnHexTestCase("537061726B2053514C", "UNICODE_CI_RTRIM", "Spark SQL"), UnHexTestCase("537061726B2053514C", "DE", "Spark SQL") ) testCases.foreach(t => { @@ -613,16 +692,30 @@ class CollationSQLExpressionsSuite "xpath_boolean", "UTF8_BINARY", true, BooleanType), XPathTestCase("12", "sum(A/B)", "xpath_short", "UTF8_BINARY", 3, ShortType), + XPathTestCase("1", "a/b", + "xpath_boolean", "UTF8_BINARY_RTRIM", true, BooleanType), + XPathTestCase("12", "sum(A/B)", + "xpath_short", "UTF8_BINARY_RTRIM", 3, ShortType), XPathTestCase("34", "sum(a/b)", "xpath_int", "UTF8_LCASE", 7, IntegerType), XPathTestCase("56", "sum(A/B)", "xpath_long", "UTF8_LCASE", 11, LongType), + XPathTestCase("34", "sum(a/b)", + "xpath_int", "UTF8_LCASE_RTRIM", 7, IntegerType), + XPathTestCase("56", "sum(A/B)", + "xpath_long", "UTF8_LCASE_RTRIM", 11, LongType), XPathTestCase("78", "sum(a/b)", "xpath_float", "UNICODE", 15.0, FloatType), XPathTestCase("90", "sum(A/B)", "xpath_double", "UNICODE", 9.0, DoubleType), + XPathTestCase("78", "sum(a/b)", + "xpath_float", "UNICODE_RTRIM", 15.0, FloatType), + XPathTestCase("90", "sum(A/B)", + "xpath_double", "UNICODE_RTRIM", 9.0, DoubleType), XPathTestCase("bcc", "a/c", "xpath_string", "UNICODE_CI", "cc", StringType("UNICODE_CI")), + XPathTestCase("bcc ", "a/c", + "xpath_string", "UNICODE_CI_RTRIM", "cc ", StringType("UNICODE_CI_RTRIM")), XPathTestCase("b1b2b3c1c2", "a/b/text()", "xpath", "UNICODE_CI", Array("b1", "b2", "b3"), ArrayType(StringType("UNICODE_CI"))) ) @@ -651,10 +744,15 @@ class 
CollationSQLExpressionsSuite val testCases = Seq( StringSpaceTestCase(1, "UTF8_BINARY", " "), + StringSpaceTestCase(1, "UTF8_BINARY_RTRIM", " "), StringSpaceTestCase(2, "UTF8_LCASE", " "), + StringSpaceTestCase(2, "UTF8_LCASE_RTRIM", " "), StringSpaceTestCase(3, "UNICODE", " "), + StringSpaceTestCase(3, "UNICODE_RTRIM", " "), StringSpaceTestCase(4, "UNICODE_CI", " "), - StringSpaceTestCase(5, "AF_CI_AI", " ") + StringSpaceTestCase(4, "UNICODE_CI_RTRIM", " "), + StringSpaceTestCase(5, "AF_CI_AI", " "), + StringSpaceTestCase(5, "AF_CI_AI_RTRIM", " ") ) // Supported collations @@ -684,9 +782,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( ToNumberTestCase("123", "UTF8_BINARY", "999", 123, DecimalType(3, 0)), + ToNumberTestCase("123", "UTF8_BINARY_RTRIM", "999", 123, DecimalType(3, 0)), ToNumberTestCase("1", "UTF8_LCASE", "0.00", 1.00, DecimalType(3, 2)), + ToNumberTestCase("1", "UTF8_LCASE_RTRIM", "0.00", 1.00, DecimalType(3, 2)), ToNumberTestCase("99,999", "UNICODE", "99,999", 99999, DecimalType(5, 0)), - ToNumberTestCase("$14.99", "UNICODE_CI", "$99.99", 14.99, DecimalType(4, 2)) + ToNumberTestCase("99,999", "UNICODE_RTRIM", "99,999", 99999, DecimalType(5, 0)), + ToNumberTestCase("$14.99", "UNICODE_CI", "$99.99", 14.99, DecimalType(4, 2)), + ToNumberTestCase("$14.99", "UNICODE_CI_RTRIM", "$99.99", 14.99, DecimalType(4, 2)) ) // Supported collations (ToNumber) @@ -754,9 +856,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( ToCharTestCase(12, "UTF8_BINARY", "999", " 12"), + ToCharTestCase(12, "UTF8_BINARY_RTRIM", "999", " 12"), ToCharTestCase(34, "UTF8_LCASE", "000D00", "034.00"), + ToCharTestCase(34, "UTF8_LCASE_RTRIM", "000D00", "034.00"), ToCharTestCase(56, "UNICODE", "$99.99", "$56.00"), - ToCharTestCase(78, "UNICODE_CI", "99D9S", "78.0+") + ToCharTestCase(56, "UNICODE_RTRIM", "$99.99", "$56.00"), + ToCharTestCase(78, "UNICODE_CI", "99D9S", "78.0+"), + ToCharTestCase(78, "UNICODE_CI_RTRIM", "99D9S", "78.0+") ) // Supported collations @@ 
-785,9 +891,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( GetJsonObjectTestCase("{\"a\":\"b\"}", "$.a", "UTF8_BINARY", "b"), + GetJsonObjectTestCase("{\"a\":\"b\"}", "$.a", "UTF8_BINARY_RTRIM", "b"), GetJsonObjectTestCase("{\"A\":\"1\"}", "$.A", "UTF8_LCASE", "1"), + GetJsonObjectTestCase("{\"A\":\"1\"}", "$.A", "UTF8_LCASE_RTRIM", "1"), GetJsonObjectTestCase("{\"x\":true}", "$.x", "UNICODE", "true"), - GetJsonObjectTestCase("{\"X\":1}", "$.X", "UNICODE_CI", "1") + GetJsonObjectTestCase("{\"x\":true}", "$.x", "UNICODE_RTRIM", "true"), + GetJsonObjectTestCase("{\"X\":1}", "$.X", "UNICODE_CI", "1"), + GetJsonObjectTestCase("{\"X\":1}", "$.X", "UNICODE_CI_RTRIM", "1") ) // Supported collations @@ -817,10 +927,16 @@ class CollationSQLExpressionsSuite val testCases = Seq( JsonTupleTestCase("{\"a\":1, \"b\":2}", "'a', 'b'", "UTF8_BINARY", Row("1", "2")), + JsonTupleTestCase("{\"a\":1, \"b\":2}", "'a', 'b'", "UTF8_BINARY_RTRIM", + Row("1", "2")), JsonTupleTestCase("{\"A\":\"3\", \"B\":\"4\"}", "'A', 'B'", "UTF8_LCASE", Row("3", "4")), + JsonTupleTestCase("{\"A\":\"3\", \"B\":\"4\"}", "'A', 'B'", "UTF8_LCASE_RTRIM", + Row("3", "4")), JsonTupleTestCase("{\"x\":true, \"y\":false}", "'x', 'y'", "UNICODE", Row("true", "false")), + JsonTupleTestCase("{\"x\":true, \"y\":false}", "'x', 'y'", "UNICODE_RTRIM", + Row("true", "false")), JsonTupleTestCase("{\"X\":null, \"Y\":null}", "'X', 'Y'", "UNICODE_CI", Row(null, null)) ) @@ -852,12 +968,20 @@ class CollationSQLExpressionsSuite val testCases = Seq( JsonToStructsTestCase("{\"a\":1, \"b\":2.0}", "a INT, b DOUBLE", "UTF8_BINARY", Row(Row(1, 2.0))), + JsonToStructsTestCase("{\"a\":1, \"b\":2.0}", "a INT, b DOUBLE", + "UTF8_BINARY_RTRIM", Row(Row(1, 2.0))), JsonToStructsTestCase("{\"A\":\"3\", \"B\":4}", "A STRING COLLATE UTF8_LCASE, B INT", "UTF8_LCASE", Row(Row("3", 4))), + JsonToStructsTestCase("{\"A\":\"3\", \"B\":4}", "A STRING COLLATE UTF8_LCASE, B INT", + "UTF8_LCASE_RTRIM", Row(Row("3", 4))), 
JsonToStructsTestCase("{\"x\":true, \"y\":null}", "x BOOLEAN, y VOID", "UNICODE", Row(Row(true, null))), + JsonToStructsTestCase("{\"x\":true, \"y\":null}", "x BOOLEAN, y VOID", + "UNICODE_RTRIM", Row(Row(true, null))), + JsonToStructsTestCase("{\"X\":null, \"Y\":false}", "X VOID, Y BOOLEAN", + "UNICODE_CI", Row(Row(null, false))), JsonToStructsTestCase("{\"X\":null, \"Y\":false}", "X VOID, Y BOOLEAN", - "UNICODE_CI", Row(Row(null, false))) + "UNICODE_CI_RTRIM", Row(Row(null, false))) ) // Supported collations @@ -886,12 +1010,20 @@ class CollationSQLExpressionsSuite val testCases = Seq( StructsToJsonTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY", Row("{\"a\":1,\"b\":2}")), + StructsToJsonTestCase("named_struct('a', 1, 'b', 2)", + "UTF8_BINARY_RTRIM", Row("{\"a\":1,\"b\":2}")), StructsToJsonTestCase("array(named_struct('a', 1, 'b', 2))", "UTF8_LCASE", Row("[{\"a\":1,\"b\":2}]")), + StructsToJsonTestCase("array(named_struct('a', 1, 'b', 2))", + "UTF8_LCASE_RTRIM", Row("[{\"a\":1,\"b\":2}]")), StructsToJsonTestCase("map('a', named_struct('b', 1))", "UNICODE", Row("{\"a\":{\"b\":1}}")), + StructsToJsonTestCase("map('a', named_struct('b', 1))", + "UNICODE_RTRIM", Row("{\"a\":{\"b\":1}}")), StructsToJsonTestCase("array(map('a', 1))", - "UNICODE_CI", Row("[{\"a\":1}]")) + "UNICODE_CI", Row("[{\"a\":1}]")), + StructsToJsonTestCase("array(map('a', 1))", + "UNICODE_CI_RTRIM", Row("[{\"a\":1}]")) ) // Supported collations @@ -919,9 +1051,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( LengthOfJsonArrayTestCase("'[1,2,3,4]'", "UTF8_BINARY", Row(4)), + LengthOfJsonArrayTestCase("'[1,2,3,4]'", "UTF8_BINARY_RTRIM", Row(4)), LengthOfJsonArrayTestCase("'[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]'", "UTF8_LCASE", Row(5)), + LengthOfJsonArrayTestCase("'[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]'", "UTF8_LCASE_RTRIM", Row(5)), LengthOfJsonArrayTestCase("'[1,2'", "UNICODE", Row(null)), - LengthOfJsonArrayTestCase("'['", "UNICODE_CI", Row(null)) + 
LengthOfJsonArrayTestCase("'[1,2'", "UNICODE_RTRIM", Row(null)), + LengthOfJsonArrayTestCase("'['", "UNICODE_CI", Row(null)), + LengthOfJsonArrayTestCase("'['", "UNICODE_CI_RTRIM", Row(null)) ) // Supported collations @@ -949,11 +1085,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( JsonObjectKeysJsonArrayTestCase("{}", "UTF8_BINARY", Row(Seq())), + JsonObjectKeysJsonArrayTestCase("{}", "UTF8_BINARY_RTRIM", + Row(Seq())), JsonObjectKeysJsonArrayTestCase("{\"k\":", "UTF8_LCASE", Row(null)), + JsonObjectKeysJsonArrayTestCase("{\"k\":", "UTF8_LCASE_RTRIM", + Row(null)), JsonObjectKeysJsonArrayTestCase("{\"k1\": \"v1\"}", "UNICODE", Row(Seq("k1"))), + JsonObjectKeysJsonArrayTestCase("{\"k1\": \"v1\"}", "UNICODE_RTRIM", + Row(Seq("k1"))), JsonObjectKeysJsonArrayTestCase("{\"k1\":1,\"k2\":{\"k3\":3, \"k4\":4}}", "UNICODE_CI", + Row(Seq("k1", "k2"))), + JsonObjectKeysJsonArrayTestCase("{\"k1\":1,\"k2\":{\"k3\":3, \"k4\":4}}", "UNICODE_CI_RTRIM", Row(Seq("k1", "k2"))) ) @@ -983,12 +1127,20 @@ class CollationSQLExpressionsSuite val testCases = Seq( SchemaOfJsonTestCase("'[{\"col\":0}]'", "UTF8_BINARY", Row("ARRAY>")), + SchemaOfJsonTestCase("'[{\"col\":0}]'", + "UTF8_BINARY_RTRIM", Row("ARRAY>")), SchemaOfJsonTestCase("'[{\"col\":01}]', map('allowNumericLeadingZeros', 'true')", "UTF8_LCASE", Row("ARRAY>")), + SchemaOfJsonTestCase("'[{\"col\":01}]', map('allowNumericLeadingZeros', 'true')", + "UTF8_LCASE_RTRIM", Row("ARRAY>")), SchemaOfJsonTestCase("'[]'", "UNICODE", Row("ARRAY")), + SchemaOfJsonTestCase("'[]'", + "UNICODE_RTRIM", Row("ARRAY")), + SchemaOfJsonTestCase("''", + "UNICODE_CI", Row("STRING")), SchemaOfJsonTestCase("''", - "UNICODE_CI", Row("STRING")) + "UNICODE_CI_RTRIM", Row("STRING")) ) // Supported collations @@ -1029,10 +1181,7 @@ class CollationSQLExpressionsSuite Map("c" -> "1", "č" -> "2", "ć" -> "3")) ) val unsupportedTestCases = Seq( - StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null), - StringToMapTestCase("a:1,b:2,c:3", "?", 
"?", "UNICODE_RTRIM", null), - StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null), - StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null)) + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null)) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -1079,9 +1228,13 @@ class CollationSQLExpressionsSuite case class RaiseErrorTestCase(errorMessage: String, collationName: String) val testCases = Seq( RaiseErrorTestCase("custom error message 1", "UTF8_BINARY"), + RaiseErrorTestCase("custom error message 1", "UTF8_BINARY_RTRIM"), RaiseErrorTestCase("custom error message 2", "UTF8_LCASE"), + RaiseErrorTestCase("custom error message 2", "UTF8_LCASE_RTRIM"), RaiseErrorTestCase("custom error message 3", "UNICODE"), - RaiseErrorTestCase("custom error message 4", "UNICODE_CI") + RaiseErrorTestCase("custom error message 3", "UNICODE_RTRIM"), + RaiseErrorTestCase("custom error message 4", "UNICODE_CI"), + RaiseErrorTestCase("custom error message 4", "UNICODE_CI_RTRIM") ) testCases.foreach(t => { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { @@ -1100,7 +1253,13 @@ class CollationSQLExpressionsSuite test("Support CurrentDatabase/Catalog/User expressions with collation") { // Supported collations - Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI", "SR_CI_AI").foreach(collationName => + Seq( + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "UNICODE", + "UNICODE_RTRIM", + "UNICODE_CI", + "SR_CI_AI").foreach(collationName => withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { val queryDatabase = sql("SELECT current_schema()") val queryCatalog = sql("SELECT current_catalog()") @@ -1116,7 +1275,14 @@ class CollationSQLExpressionsSuite test("Support Uuid misc expression with collation") { // Supported collations - Seq("UTF8_LCASE", "UNICODE", "UNICODE_CI", "NO_CI_AI").foreach(collationName => + Seq( + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "UNICODE", + "UNICODE_RTRIM", + 
"UNICODE_CI", + "UNICODE_CI_RTRIM", + "NO_CI_AI").foreach(collationName => withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { val query = s"SELECT uuid()" // Result & data type @@ -1291,11 +1457,20 @@ class CollationSQLExpressionsSuite Row(1), Seq( StructField("a", IntegerType, nullable = true) )), + XmlToStructsTestCase("

1

", "UTF8_BINARY_RTRIM", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), XmlToStructsTestCase("

true0.8

", "UTF8_LCASE", "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( StructField("A", BooleanType, nullable = true), StructField("B", DoubleType, nullable = true) )), + XmlToStructsTestCase("

true0.8

", "UTF8_LCASE_RTRIM", + "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), XmlToStructsTestCase("

Spark

", "UNICODE", "'s STRING'", "", Row("Spark"), Seq( StructField("s", StringType, nullable = true) @@ -1304,6 +1479,11 @@ class CollationSQLExpressionsSuite Row("Spark"), Seq( StructField("s", StringType("UNICODE"), nullable = true) )), + XmlToStructsTestCase("

Spark

", "UNICODE_RTRIM", + "'s STRING COLLATE UNICODE_RTRIM'", "", + Row("Spark"), Seq( + StructField("s", StringType("UNICODE_RTRIM"), nullable = true) + )), XmlToStructsTestCase("

", "UNICODE_CI", "'time Timestamp'", ", map('timestampFormat', 'dd/MM/yyyy')", Row( new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") @@ -1337,10 +1517,16 @@ class CollationSQLExpressionsSuite val testCases = Seq( SchemaOfXmlTestCase("

1

", "UTF8_BINARY", "STRUCT"), + SchemaOfXmlTestCase("

1

", "UTF8_BINARY_RTRIM", "STRUCT"), SchemaOfXmlTestCase("

true0.8

", "UTF8_LCASE", "STRUCT"), + SchemaOfXmlTestCase("

true0.8

", "UTF8_LCASE_RTRIM", + "STRUCT"), SchemaOfXmlTestCase("

", "UNICODE", "STRUCT<>"), + SchemaOfXmlTestCase("

", "UNICODE_RTRIM", "STRUCT<>"), SchemaOfXmlTestCase("

123

", "UNICODE_CI", + "STRUCT>"), + SchemaOfXmlTestCase("

123

", "UNICODE_CI_RTRIM", "STRUCT>") ) @@ -1373,6 +1559,11 @@ class CollationSQLExpressionsSuite | 1 | 2 |""".stripMargin), + StructsToXmlTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY_RTRIM", + s""" + | 1 + | 2 + |""".stripMargin), StructsToXmlTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_LCASE", s""" | true @@ -1383,6 +1574,11 @@ class CollationSQLExpressionsSuite | aa | bb |""".stripMargin), + StructsToXmlTestCase("named_struct('A', 'aa', 'B', 'bb')", "UTF8_LCASE_RTRIM", + s""" + | aa + | bb + |""".stripMargin), StructsToXmlTestCase("named_struct('A', 'aa', 'B', 'bb')", "UTF8_BINARY", s""" | aa @@ -1390,6 +1586,8 @@ class CollationSQLExpressionsSuite |""".stripMargin), StructsToXmlTestCase("named_struct()", "UNICODE", ""), + StructsToXmlTestCase("named_struct()", "UNICODE_RTRIM", + ""), StructsToXmlTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI", s""" | @@ -1421,9 +1619,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( ParseJsonTestCase("{\"a\":1,\"b\":2}", "UTF8_BINARY", "{\"a\":1,\"b\":2}"), + ParseJsonTestCase("{\"a\":1,\"b\":2}", "UTF8_BINARY_RTRIM", "{\"a\":1,\"b\":2}"), ParseJsonTestCase("{\"A\":3,\"B\":4}", "UTF8_LCASE", "{\"A\":3,\"B\":4}"), + ParseJsonTestCase("{\"A\":3,\"B\":4}", "UTF8_LCASE_RTRIM", "{\"A\":3,\"B\":4}"), ParseJsonTestCase("{\"c\":5,\"d\":6}", "UNICODE", "{\"c\":5,\"d\":6}"), - ParseJsonTestCase("{\"C\":7,\"D\":8}", "UNICODE_CI", "{\"C\":7,\"D\":8}") + ParseJsonTestCase("{\"c\":5,\"d\":6}", "UNICODE_RTRIM", "{\"c\":5,\"d\":6}"), + ParseJsonTestCase("{\"C\":7,\"D\":8}", "UNICODE_CI", "{\"C\":7,\"D\":8}"), + ParseJsonTestCase("{\"C\":7,\"D\":8}", "UNICODE_CI_RTRIM", "{\"C\":7,\"D\":8}") ) // Supported collations (ParseJson) @@ -1493,9 +1695,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( IsVariantNullTestCase("'null'", "UTF8_BINARY", result = true), + IsVariantNullTestCase("'null'", "UTF8_BINARY_RTRIM", result = true), IsVariantNullTestCase("'\"null\"'", "UTF8_LCASE", result = 
false), + IsVariantNullTestCase("'\"null\"'", "UTF8_LCASE_RTRIM", result = false), IsVariantNullTestCase("'13'", "UNICODE", result = false), - IsVariantNullTestCase("null", "UNICODE_CI", result = false) + IsVariantNullTestCase("'13'", "UNICODE_RTRIM", result = false), + IsVariantNullTestCase("null", "UNICODE_CI", result = false), + IsVariantNullTestCase("null", "UNICODE_CI_RTRIM", result = false) ) // Supported collations @@ -1524,6 +1730,7 @@ class CollationSQLExpressionsSuite val testCases = Seq( VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType), + VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY_RTRIM", 1, IntegerType), VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_LCASE", null, IntegerType), VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", StringType), @@ -1610,6 +1817,14 @@ class CollationSQLExpressionsSuite StructField("value", VariantType, nullable = false) ) ), + VariantExplodeTestCase("[\"hello\", \"world\"]", "UTF8_BINARY_RTRIM", + Row(0, "null", "\"hello\"").toString() + Row(1, "null", "\"world\"").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UTF8_BINARY_RTRIM")), + StructField("value", VariantType, nullable = false) + ) + ), VariantExplodeTestCase("[\"Spark\", \"SQL\"]", "UTF8_LCASE", Row(0, "null", "\"Spark\"").toString() + Row(1, "null", "\"SQL\"").toString(), Seq[StructField]( @@ -1618,6 +1833,14 @@ class CollationSQLExpressionsSuite StructField("value", VariantType, nullable = false) ) ), + VariantExplodeTestCase("[\"Spark\", \"SQL\"]", "UTF8_LCASE_RTRIM", + Row(0, "null", "\"Spark\"").toString() + Row(1, "null", "\"SQL\"").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UTF8_LCASE_RTRIM")), + StructField("value", VariantType, nullable = false) + ) + ), VariantExplodeTestCase("{\"a\": true, \"b\": 3.14}", "UNICODE", Row(0, "a", "true").toString() 
+ Row(1, "b", "3.14").toString(), Seq[StructField]( @@ -1626,6 +1849,14 @@ class CollationSQLExpressionsSuite StructField("value", VariantType, nullable = false) ) ), + VariantExplodeTestCase("{\"a\": true, \"b\": 3.14}", "UNICODE_RTRIM", + Row(0, "a", "true").toString() + Row(1, "b", "3.14").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UNICODE_RTRIM")), + StructField("value", VariantType, nullable = false) + ) + ), VariantExplodeTestCase("{\"A\": 9.99, \"B\": false}", "UNICODE_CI", Row(0, "A", "9.99").toString() + Row(1, "B", "false").toString(), Seq[StructField]( @@ -1661,11 +1892,17 @@ class CollationSQLExpressionsSuite val testCases = Seq( SchemaOfVariantTestCase("null", "UTF8_BINARY", "VOID"), + SchemaOfVariantTestCase("null", "UTF8_BINARY_RTRIM", "VOID"), SchemaOfVariantTestCase("[]", "UTF8_LCASE", "ARRAY"), + SchemaOfVariantTestCase("[]", "UTF8_LCASE_RTRIM", "ARRAY"), SchemaOfVariantTestCase("[{\"a\":true,\"b\":0}]", "UNICODE", "ARRAY>"), + SchemaOfVariantTestCase("[{\"a\":true,\"b\":0}]", "UNICODE_RTRIM", + "ARRAY>"), SchemaOfVariantTestCase("[{\"A\":\"x\",\"B\":-1.00}]", "UNICODE_CI", - "ARRAY>") + "ARRAY>"), + SchemaOfVariantTestCase("[{\"A\":\"x\",\"B\":-1.00}]", "UNICODE_CI_RTRIM", + "ARRAY>") ) // Supported collations @@ -1692,11 +1929,18 @@ class CollationSQLExpressionsSuite val testCases = Seq( SchemaOfVariantAggTestCase("('1'), ('2'), ('3')", "UTF8_BINARY", "BIGINT"), + SchemaOfVariantAggTestCase("('1'), ('2'), ('3')", "UTF8_BINARY_RTRIM", "BIGINT"), SchemaOfVariantAggTestCase("('true'), ('false'), ('true')", "UTF8_LCASE", "BOOLEAN"), + SchemaOfVariantAggTestCase("('true'), ('false'), ('true')", "UTF8_LCASE_RTRIM", "BOOLEAN"), SchemaOfVariantAggTestCase("('{\"a\": 1}'), ('{\"b\": true}'), ('{\"c\": 1.23}')", "UNICODE", "OBJECT"), + SchemaOfVariantAggTestCase("('{\"a\": 1}'), ('{\"b\": true}'), ('{\"c\": 1.23}')", + "UNICODE_RTRIM", "OBJECT"), + 
SchemaOfVariantAggTestCase("('{\"A\": \"x\"}'), ('{\"B\": 9.99}'), ('{\"C\": 0}')", + "UNICODE_CI", "OBJECT"), SchemaOfVariantAggTestCase("('{\"A\": \"x\"}'), ('{\"B\": 9.99}'), ('{\"C\": 0}')", - "UNICODE_CI", "OBJECT") + "UNICODE_CI_RTRIM", "OBJECT" + ) ) // Supported collations @@ -1716,7 +1960,16 @@ class CollationSQLExpressionsSuite test("Support InputFileName expression with collation") { // Supported collations - Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "MT_CI_AI").foreach(collationName => { + Seq( + "UTF8_BINARY", + "UTF8_BINARY_RTRIM", + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "UNICODE", + "UNICODE_RTRIM", + "UNICODE_CI", + "UNICODE_CI_RTRIM", + "MT_CI_AI").foreach(collationName => { val query = s""" |select input_file_name() @@ -1735,9 +1988,13 @@ class CollationSQLExpressionsSuite case class DateFormatTestCase[R](date: String, format: String, collation: String, result: R) val testCases = Seq( DateFormatTestCase("2021-01-01", "yyyy-MM-dd", "UTF8_BINARY", "2021-01-01"), + DateFormatTestCase("2021-01-01", "yyyy-MM-dd", "UTF8_BINARY_RTRIM", "2021-01-01"), DateFormatTestCase("2021-01-01", "yyyy-dd", "UTF8_LCASE", "2021-01"), + DateFormatTestCase("2021-01-01", "yyyy-dd", "UTF8_LCASE_RTRIM", "2021-01"), DateFormatTestCase("2021-01-01", "yyyy-MM-dd", "UNICODE", "2021-01-01"), - DateFormatTestCase("2021-01-01", "yyyy", "UNICODE_CI", "2021") + DateFormatTestCase("2021-01-01", "yyyy-MM-dd", "UNICODE_RTRIM", "2021-01-01"), + DateFormatTestCase("2021-01-01", "yyyy", "UNICODE_CI", "2021"), + DateFormatTestCase("2021-01-01", "yyyy", "UNICODE_CI_RTRIM", "2021") ) for { @@ -1764,7 +2021,16 @@ class CollationSQLExpressionsSuite } test("Support mode for string expression with collation - Basic Test") { - Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode", "NL_AI").foreach { collationId => + Seq( + "utf8_binary", + "utf8_binary_rtrim", + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "unicode_ci", + "unicode_ci_rtrim", + "unicode", + "unicode_rtrim", + 
"NL_AI").foreach { collationId => val query = s"SELECT mode(collate('abc', '${collationId}'))" checkAnswer(sql(query), Row("abc")) assert(sql(query).schema.fields.head.dataType.sameType(StringType(collationId))) @@ -1775,9 +2041,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("SR", Map("c" -> 3L, "č" -> 2L, "Č" -> 2L), "c") ) testCases.foreach(t => { @@ -1812,9 +2082,14 @@ class CollationSQLExpressionsSuite val testCasesUTF8String = Seq( UTF8StringModeTestCase("utf8_binary", bufferValuesUTF8String, "a"), + UTF8StringModeTestCase("utf8_binary_rtrim", bufferValuesUTF8String, "a"), UTF8StringModeTestCase("UTF8_LCASE", bufferValuesUTF8String, "b"), + UTF8StringModeTestCase("UTF8_LCASE_RTRIM", bufferValuesUTF8String, "b"), UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"), - UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a")) + UTF8StringModeTestCase("unicode_ci_rtrim", bufferValuesUTF8String, "b"), + UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"), + UTF8StringModeTestCase("unicode_rtrim", bufferValuesUTF8String, "a") + ) testCasesUTF8String.foreach ( t => { val buffer = new OpenHashMap[AnyRef, Long](5) @@ -1842,9 +2117,13 @@ class CollationSQLExpressionsSuite } val testCasesUTF8String = Seq( 
UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("utf8_binary_rtrim", bufferValuesComplex, "[a,a,a]"), UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("UTF8_LCASE_RTRIM", bufferValuesComplex, "[b,b,b]"), UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"), - UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]")) + UTF8StringModeTestCase("unicode_ci_rtrim", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("unicode_rtrim", bufferValuesComplex, "[a,a,a]")) testCasesUTF8String.foreach { t => val buffer = new OpenHashMap[AnyRef, Long](5) @@ -1862,9 +2141,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) testCases.foreach(t => { val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1887,9 +2170,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), 
ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1912,9 +2199,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1938,9 +2229,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" 
-> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1964,9 +2259,13 @@ class CollationSQLExpressionsSuite case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("UTF8_LCASE_RTRIM", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1991,8 +2290,11 @@ class CollationSQLExpressionsSuite case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("utf8_binary_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("unicode_rtrim", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), + ModeTestCase("utf8_lcase_rtrim", 
Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") ).foreach { t1 => def getValuesToAdd(t: ModeTestCase): String = { @@ -2023,7 +2325,12 @@ class CollationSQLExpressionsSuite for { collateKey <- Seq(true, false) collateVal <- Seq(true, false) - defaultCollation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE") + defaultCollation <- Seq( + "UTF8_BINARY", + "UTF8_BINARY_RTRIM", + "UTF8_LCASE", + "UTF8_LCASE_RTRIM", + "UNICODE") } { val mapKey = if (collateKey) "'a' collate utf8_lcase" else "'a'" val mapVal = if (collateVal) "'b' collate utf8_lcase" else "'b'" @@ -2420,7 +2727,8 @@ class CollationSQLExpressionsSuite "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", true), ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", "A5Cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", false), - + ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", + "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary_rtrim", true), ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary", "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_lcase", true), ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary", @@ -3166,14 +3474,22 @@ class CollationSQLExpressionsSuite ) val testCases = Seq( - HyperLogLogPlusPlusTestCase("utf8_binary", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA", - "aA", "Aa", "aa"), Seq(Row(10))), - HyperLogLogPlusPlusTestCase("utf8_lcase", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA", - "aA", "Aa", "aa"), Seq(Row(5))), + HyperLogLogPlusPlusTestCase("utf8_binary", Seq("a", "a", "A", "z", "zz", "ZZ", "w", + "AA", "aA", "Aa", "aa"), Seq(Row(10))), + HyperLogLogPlusPlusTestCase("utf8_binary_rtrim", Seq("a ", "a", "a", "A", "z", "zz", "ZZ", + "w", "AA", "aA", "Aa", "aa"), Seq(Row(10))), + HyperLogLogPlusPlusTestCase("utf8_lcase", Seq("a", "a", "A", "z", "zz", "ZZ", "w", + "AA", "aA", "Aa", "aa"), Seq(Row(5))), + 
HyperLogLogPlusPlusTestCase("utf8_lcase_rtrim", Seq("a ", "a", "a", "A", "z", "zz", "ZZ", "w", + "AA", "aA", "Aa", "aa"), Seq(Row(5))), HyperLogLogPlusPlusTestCase("UNICODE", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA", "aA", "Aa", "aa"), Seq(Row(9))), + HyperLogLogPlusPlusTestCase("UNICODE_RTRIM", Seq("a ", "a", "a", "A", "z", "zz", "ZZ", "w", + "AA", "aA", "Aa", "aa"), Seq(Row(9))), HyperLogLogPlusPlusTestCase("UNICODE_CI", Seq("a", "a", "A", "z", "zz", "ZZ", "w", "AA", - "aA", "Aa", "aa"), Seq(Row(5))) + "aA", "Aa", "aa"), Seq(Row(5))), + HyperLogLogPlusPlusTestCase("UNICODE_CI_RTRIM", Seq("a ", "a", "a", "A", "z", "zz", "ZZ", "w", + "AA", "aA", "Aa", "aa"), Seq(Row(5))) ) testCases.foreach( t => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 11f2c4b997a4b..a8fe36c9ba394 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -696,6 +696,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s"IN (COLLATE('aa', 'UTF8_LCASE'))"), Seq(Row("a"), Row("A"))) checkAnswer(sql(s"SELECT c1 FROM $tableName where (c1 || 'a') " + s"IN (COLLATE('aa', 'UTF8_BINARY'))"), Seq(Row("a"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName where c1 || 'a' " + + s"IN (COLLATE('aa', 'UTF8_LCASE_RTRIM'))"), Seq(Row("a"), Row("A"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName where (c1 || 'a') " + + s"IN (COLLATE('aa', 'UTF8_BINARY_RTRIM'))"), Seq(Row("a"))) + // columns have different collation checkError( @@ -806,6 +811,16 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) ) + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE UNICODE_CI_RTRIM)") + }, + condition = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "explicitTypes" -> 
""""STRING COLLATE UNICODE", "STRING COLLATE UNICODE_CI_RTRIM"""" + ) + ) + checkError( exception = intercept[AnalysisException] { sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)")