Skip to content

Commit

Permalink
[SPARK-47296][SQL][COLLATION] Fail unsupported functions for non-bina…
Browse files Browse the repository at this point in the history
…ry collations

### What changes were proposed in this pull request?

### Why are the changes needed?
Currently, all `StringType` arguments passed to built-in string functions in Spark SQL get treated as binary strings. This behaviour is incorrect for almost all collationIds except the default (0), and we should instead warn the user if they try to use an unsupported collation for the given function. Over time, we should implement the appropriate support for these (function, collation) pairs, but until then - we should have a way to fail unsupported statements in query analysis.

### Does this PR introduce _any_ user-facing change?
Yes, users will now get appropriate errors when they try to use an unsupported collation with a given string function.

### How was this patch tested?
Tests in CollationSuite to check if these functions work for binary collations and throw exceptions for others.

### Was this patch authored or co-authored using generative AI tooling?
Yes.

Closes apache#45422 from uros-db/regexp-functions.

Lead-authored-by: Uros Bojanic <[email protected]>
Co-authored-by: Mihailo Milosevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
2 people authored and cloud-fan committed Mar 20, 2024
1 parent a3c04ec commit 8762e25
Show file tree
Hide file tree
Showing 9 changed files with 637 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
* equality and hashing).
*/
def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation
def isLowercaseCollation: Boolean = collationId == CollationFactory.LOWERCASE_COLLATION_ID

/**
* Type name that is shown to the customer.
Expand All @@ -54,8 +55,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa

override def hashCode(): Int = collationId.hashCode()

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]

/**
* The default size of a value of the StringType is 20 bytes.
*/
Expand All @@ -65,6 +64,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
}

/**
* Use StringType for expressions supporting only binary collation.
*
* @since 1.3.0
*/
@Stable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ object AnsiTypeCoercion extends TypeCoercionBase {
case (NullType, target) if !target.isInstanceOf[TypeCollection] =>
Some(target.defaultConcreteType)

// If a function expects a StringType, no StringType instance should be implicitly cast to
// StringType with a collation that's not accepted (aka. lockdown unsupported collations).
case (_: StringType, StringType) => None
case (_: StringType, _: StringTypeCollated) => None

// This type coercion system will allow implicit converting String type as other
// primitive types, in case of breaking too many existing Spark SQL queries.
case (StringType, a: AtomicType) =>
Expand Down Expand Up @@ -215,6 +220,16 @@ object AnsiTypeCoercion extends TypeCoercionBase {
None
}

// "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but
// ANSIStoreAssign is generally expected to work with StringTypes
case (_, st: StringTypeCollated) =>
if (Cast.canANSIStoreAssign(inType, st.defaultConcreteType)) {
Some(st.defaultConcreteType)
}
else {
None
}

// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
case (_, TypeCollection(types)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,9 @@ object TypeCoercion extends TypeCoercionBase {
case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType
case (StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
case (any: AtomicType, StringType) if any != StringType => StringType
case (any: AtomicType, StringType) if !any.isInstanceOf[StringType] => StringType
case (any: AtomicType, st: StringTypeCollated)
if !any.isInstanceOf[StringType] => st.defaultConcreteType

// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

object CollationTypeConstraints {

def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = {
val collationName = CollationFactory.fetchCollation(collationId).collationName
// Additional check needed for collation compatibility
dataTypes.collectFirst {
case stringType: StringType if stringType.collationId != collationId =>
val collation = CollationFactory.fetchCollation(stringType.collationId)
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> collationName,
"collationNameRight" -> collation.collationName
)
)
} getOrElse TypeCheckResult.TypeCheckSuccess
}

}

/**
* StringTypeCollated is an abstract class for StringType with collation support.
*/
abstract class StringTypeCollated extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = StringType
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
*/
case object StringTypeBinary extends StringTypeCollated {
override private[sql] def simpleString: String = "string_binary"
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
*/
case object StringTypeBinaryLcase extends StringTypeCollated {
override private[sql] def simpleString: String = "string_binary_lcase"
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation ||
other.asInstanceOf[StringType].isLowercaseCollation)
}

/**
* Use StringTypeAnyCollation for expressions supporting all possible collation types.
*/
case object StringTypeAnyCollation extends StringTypeCollated {
override private[sql] def simpleString: String = "string_any_collation"
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ trait String2StringExpression extends ImplicitCastInputTypes {

def convert(v: UTF8String): UTF8String

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType)
override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

protected override def nullSafeEval(input: Any): Any =
convert(input.asInstanceOf[UTF8String])
Expand Down Expand Up @@ -501,26 +501,15 @@ abstract class StringPredicate extends BinaryExpression

def compare(l: UTF8String, r: UTF8String): Boolean

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def checkInputDataTypes(): TypeCheckResult = {
val checkResult = super.checkInputDataTypes()
if (checkResult.isFailure) {
return checkResult
}
// Additional check needed for collation compatibility
val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId
if (collationId != rightCollationId) {
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName,
"collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName
)
)
} else {
TypeCheckResult.TypeCheckSuccess
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
return defaultCheck
}
CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
}

protected override def nullSafeEval(input1: Any, input2: Any): Any =
Expand Down Expand Up @@ -1976,7 +1965,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
override def dataType: DataType = str.dataType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType)

override def first: Expression = str
override def second: Expression = pos
Expand Down
Loading

0 comments on commit 8762e25

Please sign in to comment.