diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0878abbd0a843..4cf7d8efb96a5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -460,27 +460,6 @@ object DataType {
       // String types with possibly different collations are compatible.
       case (_: StringType, _: StringType) => true
 
-      case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) =>
-        (fromContainsNull == toContainsNull) &&
-        equalsIgnoreCompatibleCollation(fromElement, toElement)
-
-      case (
-          MapType(fromKey, fromValue, fromContainsNull),
-          MapType(toKey, toValue, toContainsNull)) =>
-        fromContainsNull == toContainsNull &&
-        // Map keys cannot change collation.
-        fromKey == toKey &&
-        equalsIgnoreCompatibleCollation(fromValue, toValue)
-
-      case (StructType(fromFields), StructType(toFields)) =>
-        fromFields.length == toFields.length &&
-        fromFields.zip(toFields).forall { case (fromField, toField) =>
-          fromField.name == toField.name &&
-          fromField.nullable == toField.nullable &&
-          fromField.metadata == toField.metadata &&
-          equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType)
-        }
-
       case (fromDataType, toDataType) => fromDataType == toDataType
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 16899b656f304..724014273fed4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1546,10 +1546,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           .map(dt => col.field.copy(dataType = dt))
           .getOrElse(col.field)
         val newDataType = a.dataType.get
-        val sameTypeExceptCollations =
-          DataType.equalsIgnoreCompatibleCollation(field.dataType, newDataType)
         newDataType match {
-          case _ if sameTypeExceptCollations => // Allow changing type collations.
           case _: StructType => alter.failAnalysis(
             "CANNOT_UPDATE_FIELD.STRUCT_TYPE",
             Map("table" -> toSQLId(table.name), "fieldName" -> toSQLId(fieldName)))
@@ -1576,10 +1573,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           case (CharType(l1), CharType(l2)) => l1 == l2
           case (CharType(l1), VarcharType(l2)) => l1 <= l2
           case (VarcharType(l1), VarcharType(l2)) => l1 <= l2
-          case _ => Cast.canUpCast(from, to)
+          case _ =>
+            Cast.canUpCast(from, to) ||
+              DataType.equalsIgnoreCompatibleCollation(field.dataType, newDataType)
         }
-
-        if (!sameTypeExceptCollations && !canAlterColumnType(field.dataType, newDataType)) {
+        if (!canAlterColumnType(field.dataType, newDataType)) {
           alter.failAnalysis(
             errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
             messageParameters = Map(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 3552beb210a1b..d5fc4d87bb6ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -722,7 +722,7 @@ class DataTypeSuite extends SparkFunSuite with SQLHelper {
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(StringType),
       ArrayType(StringType("UTF8_LCASE")),
-      expected = true
+      expected = false
     )
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(StringType),
@@ -732,12 +732,12 @@ class DataTypeSuite extends SparkFunSuite with SQLHelper {
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(ArrayType(StringType)),
       ArrayType(ArrayType(StringType("UTF8_LCASE"))),
-      expected = true
+      expected = false
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType, StringType),
       MapType(StringType, StringType("UTF8_LCASE")),
-      expected = true
+      expected = false
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), StringType),
@@ -747,7 +747,7 @@ class DataTypeSuite extends SparkFunSuite with SQLHelper {
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
       MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
-      expected = true
+      expected = false
    )
     checkEqualsIgnoreCompatibleCollation(
      MapType(ArrayType(StringType), IntegerType),
@@ -762,12 +762,12 @@ class DataTypeSuite extends SparkFunSuite with SQLHelper {
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", StringType) :: Nil),
       StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
-      expected = true
+      expected = false
     )
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", ArrayType(StringType)) :: Nil),
       StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
-      expected = true
+      expected = false
     )
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 7d7c95a24ca69..9a47491b0cca4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAg
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
-import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, Metadata, MetadataBuilder, StringType, StructField, StructType}
 
 class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
@@ -529,15 +529,55 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           |""".stripMargin)
       sql(s"INSERT INTO $tableName VALUES ('a', array('b'), map(1, 'c'), struct('d'))")
       sql(s"ALTER TABLE $tableName ALTER COLUMN c1 TYPE STRING COLLATE UTF8_LCASE")
-      sql(s"ALTER TABLE $tableName ALTER COLUMN c2 TYPE ARRAY<STRING COLLATE UNICODE_CI>")
-      sql(s"ALTER TABLE $tableName ALTER COLUMN c3 TYPE MAP<INT, STRING COLLATE UTF8_BINARY>")
-      sql(s"ALTER TABLE $tableName ALTER COLUMN c4 TYPE STRUCT<t: STRING COLLATE UNICODE>")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c2.element TYPE STRING COLLATE UNICODE_CI")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c3.value TYPE STRING COLLATE UTF8_BINARY")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c4.t TYPE STRING COLLATE UNICODE")
       checkAnswer(sql(s"SELECT collation(c1), collation(c2[0]), " +
         s"collation(c3[1]), collation(c4.t) FROM $tableName"),
         Seq(Row("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY", "UNICODE")))
     }
   }
 
+  test("SPARK-50262: Alter column with collation preserve metadata") {
+    def createMetadata(column: String): Metadata =
+      new MetadataBuilder().putString("key", column).build()
+
+    val tableName = "testcat.alter_column_tbl"
+    withTable(tableName) {
+      val df = spark.createDataFrame(
+        java.util.List.of[Row](),
+        StructType(Seq(
+          StructField("c1", StringType, metadata = createMetadata("c1")),
+          StructField("c2", ArrayType(StringType), metadata = createMetadata("c2")),
+          StructField("c3", MapType(IntegerType, StringType), metadata = createMetadata("c3")),
+          StructField("c4",
+            StructType(Seq(StructField("t", StringType, metadata = createMetadata("c4t")))),
+            metadata = createMetadata("c4"))
+        ))
+      )
+      df.write.format("parquet").saveAsTable(tableName)
+
+      sql(s"INSERT INTO $tableName VALUES ('a', array('b'), map(1, 'c'), struct('d'))")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c1 TYPE STRING COLLATE UTF8_LCASE")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c2.element TYPE STRING COLLATE UNICODE_CI")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c3.value TYPE STRING COLLATE UTF8_BINARY")
+      sql(s"ALTER TABLE $tableName ALTER COLUMN c4.t TYPE STRING COLLATE UNICODE")
+      val testCatalog = catalog("testcat").asTableCatalog
+      val tableSchema = testCatalog.loadTable(Identifier.of(Array(), "alter_column_tbl")).schema()
+      val c1Metadata = tableSchema.find(_.name == "c1").get.metadata
+      assert(c1Metadata === createMetadata("c1"))
+      val c2Metadata = tableSchema.find(_.name == "c2").get.metadata
+      assert(c2Metadata === createMetadata("c2"))
+      val c3Metadata = tableSchema.find(_.name == "c3").get.metadata
+      assert(c3Metadata === createMetadata("c3"))
+      val c4Metadata = tableSchema.find(_.name == "c4").get.metadata
+      assert(c4Metadata === createMetadata("c4"))
+      val c4tMetadata = tableSchema.find(_.name == "c4").get.dataType
+        .asInstanceOf[StructType].find(_.name == "t").get.metadata
+      assert(c4tMetadata === createMetadata("c4t"))
+    }
+  }
+
   test("SPARK-47210: Implicit casting of collated strings") {
     val tableName = "parquet_dummy_implicit_cast_t22"
     withTable(tableName) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index e07f6406901e0..fec7183bc75e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -2344,18 +2344,22 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
       sql("CREATE TABLE t2(col ARRAY<STRING>) USING parquet")
       sql("INSERT INTO t2 VALUES (ARRAY('a'))")
       checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY"))
-      sql("ALTER TABLE t2 ALTER COLUMN col TYPE ARRAY<STRING COLLATE UTF8_LCASE>")
-      checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_LCASE"))
+      assertThrows[AnalysisException] {
+        sql("ALTER TABLE t2 ALTER COLUMN col TYPE ARRAY<STRING COLLATE UTF8_LCASE>")
+      }
+      checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY"))
 
       // `MapType` with collation.
       sql("CREATE TABLE t3(col MAP<STRING, STRING>) USING parquet")
       sql("INSERT INTO t3 VALUES (MAP('k', 'v'))")
       checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY"))
-      sql(
-        """
-          |ALTER TABLE t3 ALTER COLUMN col TYPE
-          |MAP<STRING, STRING COLLATE UTF8_LCASE>""".stripMargin)
-      checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_LCASE"))
+      assertThrows[AnalysisException] {
+        sql(
+          """
+            |ALTER TABLE t3 ALTER COLUMN col TYPE
+            |MAP<STRING, STRING COLLATE UTF8_LCASE>""".stripMargin)
+      }
+      checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY"))
 
       // Invalid change of map key collation.
       val alterMap =
@@ -2367,7 +2371,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
         },
         condition = "NOT_SUPPORTED_CHANGE_COLUMN",
         parameters = Map(
-          "originType" -> "\"MAP<STRING, STRING COLLATE UTF8_LCASE>\"",
+          "originType" -> "\"MAP<STRING, STRING>\"",
           "originName" -> "`col`",
           "table" -> "`spark_catalog`.`default`.`t3`",
           "newType" -> "\"MAP<STRING COLLATE UTF8_LCASE, STRING COLLATE UTF8_LCASE>\"",
@@ -2380,8 +2384,10 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
       sql("CREATE TABLE t4(col STRUCT<a: STRING>) USING parquet")
       sql("INSERT INTO t4 VALUES (NAMED_STRUCT('a', 'value'))")
       checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY"))
-      sql("ALTER TABLE t4 ALTER COLUMN col TYPE STRUCT<a: STRING COLLATE UTF8_LCASE>")
-      checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_LCASE"))
+      assertThrows[AnalysisException] {
+        sql("ALTER TABLE t4 ALTER COLUMN col TYPE STRUCT<a: STRING COLLATE UTF8_LCASE>")
+      }
+      checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY"))
     }
   }
 
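
A minimal sketch of the behavior this patch settles on, illustrative only and not part of the patch; it assumes the snippet lives alongside DataTypeSuite in the org.apache.spark.sql.types package, since equalsIgnoreCompatibleCollation and the StringType(collationName) factory it uses are package-visible helpers rather than public API:

    import org.apache.spark.sql.types._

    // Top-level strings may still differ only in collation and compare as equal.
    assert(DataType.equalsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE")))

    // Collation differences nested inside complex types are no longer ignored,
    // so each of these pairs now compares as unequal.
    assert(!DataType.equalsIgnoreCompatibleCollation(
      ArrayType(StringType), ArrayType(StringType("UTF8_LCASE"))))
    assert(!DataType.equalsIgnoreCompatibleCollation(
      MapType(StringType, StringType), MapType(StringType, StringType("UTF8_LCASE"))))
    assert(!DataType.equalsIgnoreCompatibleCollation(
      StructType(StructField("a", StringType) :: Nil),
      StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil)))

As the suites above show, the DDL-level consequence is that a nested collation change goes through the field path (for example, ALTER TABLE t ALTER COLUMN c2.element TYPE STRING COLLATE UNICODE_CI) rather than by re-declaring the whole ARRAY, MAP, or STRUCT type, which is also what keeps the enclosing field's metadata intact.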