[SPARK-50060][SQL] Disabled conversion between different collated types in TypeCoercion and AnsiTypeCoercion

### What changes were proposed in this pull request?

In this PR, I propose disabling implicit casting between different collated string types in set operators (`union [all/distinct]`, `except [all/distinct]`, and `intersect [all/distinct]`).
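
For example, a query like the following (one of the golden-file tests added in this PR) now fails analysis with the `INCOMPATIBLE_COLUMN_TYPE` error instead of implicitly coercing one side:

```sql
-- Both sides produce string columns, but with different collations, so the UNION is rejected.
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ')
union
select col1 collate unicode_ci from values ('aaa'), ('bbb');
```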

### Why are the changes needed?

These changes are needed to ensure the correct behavior of set operators over collated strings. This way, the user gets an appropriate error when the collations of the two sides conflict.
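
The same conflict can arise with the session default collation. A minimal sketch, mirroring the added `CollationSQLExpressionsSuite` test and assuming the session default collation (`SqlApiConf.DEFAULT_COLLATION`) is set to `UNICODE_CI`:

```sql
-- Assumes the session default collation is UNICODE_CI; the untyped literal 'A' then takes
-- the default collation, which conflicts with the explicit UTF8_LCASE on the other side.
select 'a' collate utf8_lcase
intersect
select 'A';
-- fails with INCOMPATIBLE_COLUMN_TYPE (sqlState 42825)
```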

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tests were added in `CollationSQLExpressionsSuite` and in the `collations.sql` golden file.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48585 from vladanvasi-db/vladanvasi-db/set-operators-collations-behavior-fix.

Authored-by: Vladan Vasić <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
vladanvasi-db authored and cloud-fan committed Oct 25, 2024
1 parent 3da8f70 commit 5a43c97
Showing 6 changed files with 328 additions and 3 deletions.
@@ -148,6 +148,9 @@ object AnsiTypeCoercion extends TypeCoercionBase {
// interval type the string should be promoted as. There are many possible interval
// types, such as year interval, month interval, day interval, hour interval, etc.
case (_: StringType, _: AnsiIntervalType) => None
// [SPARK-50060] If a binary operation contains two collated string types with different
// collation IDs, we can't decide which collation ID the result should have.
case (st1: StringType, st2: StringType) if st1.collationId != st2.collationId => None
case (_: StringType, a: AtomicType) => Some(a)
case (other, st: StringType) if !other.isInstanceOf[StringType] =>
findWiderTypeForString(st, other)
@@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType,
StringTypeWithCollation}
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence

@@ -905,6 +904,9 @@ object TypeCoercion extends TypeCoercionBase {

/** Promotes all the way to StringType. */
private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
// [SPARK-50060] If a binary operation contains two collated string types with different
// collation IDs, we can't decide which collation ID the result should have.
case (st1: StringType, st2: StringType) if st1.collationId != st2.collationId => None
case (st: StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(st)
case (t1: AtomicType, st: StringType) if t1 != BinaryType && t1 != BooleanType => Some(st)
case _ => None
@@ -1014,7 +1016,7 @@ object TypeCoercion extends TypeCoercionBase {
case (_: StringType, datetime: DatetimeType) => datetime
case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType
case (_: StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
// Cast any atomic type to string except if there are strings with different collations.
case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st
case (any: AtomicType, st: AbstractStringType)
if !any.isInstanceOf[StringType] =>
@@ -206,6 +206,131 @@ Intersect false
+- LocalRelation [col1#x]


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "EXCEPT",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 162,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except all select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "EXCEPT ALL",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 166,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except all select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "UNION",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 161,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "UNION",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 165,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "INTERSECT",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 156,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
create table t1 (c1 struct<utf8_binary: string collate utf8_binary, utf8_lcase: string collate utf8_lcase>) USING PARQUET
-- !query analysis
7 changes: 7 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -49,6 +49,13 @@ select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), (
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate utf8_lcase from values ('aaa'), ('bbb');
select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate utf8_lcase from values ('aaa'), ('bbb');

-- set operations with conflicting collations
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except select col1 collate unicode_ci from values ('aaa'), ('bbb');
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except all select col1 collate unicode_ci from values ('aaa'), ('bbb');
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union select col1 collate unicode_ci from values ('aaa'), ('bbb');
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate unicode_ci from values ('aaa'), ('bbb');
select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate unicode_ci from values ('aaa'), ('bbb');

-- create table with struct field
create table t1 (c1 struct<utf8_binary: string collate utf8_binary, utf8_lcase: string collate utf8_lcase>) USING PARQUET;

135 changes: 135 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -220,6 +220,141 @@ aaa
bbb


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "EXCEPT",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 162,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except all select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "EXCEPT ALL",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 166,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') except all select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "UNION",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 161,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "UNION",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 165,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('AAA'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') union all select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate unicode_ci from values ('aaa'), ('bbb')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INCOMPATIBLE_COLUMN_TYPE",
"sqlState" : "42825",
"messageParameters" : {
"columnOrdinalNumber" : "first",
"dataType1" : "\"STRING COLLATE UNICODE_CI\"",
"dataType2" : "\"STRING COLLATE UTF8_LCASE\"",
"hint" : "",
"operator" : "INTERSECT",
"tableOrdinalNumber" : "second"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 156,
"fragment" : "select col1 collate utf8_lcase from values ('aaa'), ('bbb'), ('BBB'), ('zzz'), ('ZZZ') intersect select col1 collate unicode_ci from values ('aaa'), ('bbb')"
} ]
}


-- !query
create table t1 (c1 struct<utf8_binary: string collate utf8_binary, utf8_lcase: string collate utf8_lcase>) USING PARQUET
-- !query schema
@@ -3260,6 +3260,59 @@ class CollationSQLExpressionsSuite
}
}

test("SPARK-50060: set operators with conflicting collations") {
val setOperators = Seq[(String, Int, Int)](
("UNION", 64, 45),
("INTERSECT", 68, 49),
("EXCEPT", 65, 46))

for {
ansiEnabled <- Seq(true, false)
(operator, stopExplicit, stopDefault) <- setOperators
} {
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
SqlApiConf.DEFAULT_COLLATION -> "UNICODE_CI") {
val explicitConflictQuery =
s"SELECT 'a' COLLATE UTF8_LCASE $operator SELECT 'A' COLLATE UNICODE_CI"
checkError(
exception = intercept[AnalysisException] {
sql(explicitConflictQuery)
},
condition = "INCOMPATIBLE_COLUMN_TYPE",
parameters = Map(
"columnOrdinalNumber" -> "first",
"tableOrdinalNumber" -> "second",
"dataType1" -> "\"STRING COLLATE UNICODE_CI\"",
"dataType2" -> "\"STRING COLLATE UTF8_LCASE\"",
"operator" -> operator,
"hint" -> ""),
context = ExpectedContext(
fragment = explicitConflictQuery,
start = 0,
stop = stopExplicit))

val defaultConflictQuery =
s"SELECT 'a' COLLATE UTF8_LCASE $operator SELECT 'A'"
checkError(
exception = intercept[AnalysisException] {
sql(defaultConflictQuery)
},
condition = "INCOMPATIBLE_COLUMN_TYPE",
parameters = Map(
"columnOrdinalNumber" -> "first",
"tableOrdinalNumber" -> "second",
"dataType1" -> "\"STRING COLLATE UNICODE_CI\"",
"dataType2" -> "\"STRING COLLATE UTF8_LCASE\"",
"operator" -> operator,
"hint" -> ""),
context = ExpectedContext(
fragment = defaultConflictQuery,
start = 0,
stop = stopDefault))
}
}
}

test("Support HyperLogLogPlusPlus expression with collation") {
case class HyperLogLogPlusPlusTestCase(
collation: String,