From b45045e56f81deab1aec829dfee310991f17a8bd Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Fri, 29 Nov 2024 10:45:33 +0800
Subject: [PATCH] [SPARK-49992][SQL] Default collation resolution for DDL and DML queries

### What changes were proposed in this pull request?

This PR stops DDL commands (create/alter view/table, add/replace columns) from using the session-level collation. It also moves the resolution of the default collation from the parser into the analyzer.

However, because the default string type is detected by reference equality with the `StringType` object, we cannot simply replace that object with `StringType("UTF8_BINARY")`: the two compare as equal, so the tree node framework would return the old plan unchanged. We therefore perform the resolution twice, first turning the `StringType` object into a `TemporaryStringType` and then turning that into `StringType("UTF8_BINARY")`, which is no longer considered a default string type.

In addition, the dependent rules `ResolveInlineTables` and `CollationTypeCoercion` are updated so that they do not run while the plan still contains unresolved string types.

### Why are the changes needed?

The default collation for DDL commands should be tied to the object being created or altered (e.g., table, view, schema) rather than to the session-level setting. Since object-level collations are not yet supported, we assume the UTF8_BINARY collation by default for now.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48962 from stefankandic/fixSessionCollationOrder.

Lead-authored-by: Stefan Kandic
Co-authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
---
 .../catalyst/parser/DataTypeAstBuilder.scala | 2 +-
 .../internal/types/AbstractStringType.scala | 3 +-
 .../apache/spark/sql/types/ArrayType.scala | 9 +
 .../org/apache/spark/sql/types/DataType.scala | 7 +
 .../org/apache/spark/sql/types/MapType.scala | 12 +
 .../apache/spark/sql/types/StringType.scala | 11 +-
 .../apache/spark/sql/types/StructType.scala | 12 +
 .../sql/catalyst/analysis/Analyzer.scala | 1 +
 .../analysis/CollationTypeCoercion.scala | 14 +-
 .../analysis/ResolveDefaultStringTypes.scala | 188 +++++++
 .../analysis/ResolveInlineTables.scala | 6 +-
 .../analysis/TypeCoercionHelper.scala | 4 +-
 .../spark/sql/catalyst/expressions/misc.scala | 4 +-
 .../expressions/stringExpressions.scala | 2 +-
 .../sql/catalyst/parser/AstBuilder.scala | 40 +-
 .../catalyst/plans/logical/v2Commands.scala | 6 +
 .../sql/catalyst/rules/RuleExecutor.scala | 31 +-
 .../apache/spark/sql/internal/SQLConf.scala | 2 +-
 .../spark/sql/types/DataTypeSuite.scala | 2 +-
 .../analysis/ResolveSessionCatalog.scala | 5 +
 .../spark/sql/execution/datasources/ddl.scala | 4 +-
 .../sql/CollationSQLExpressionsSuite.scala | 15 +-
 .../spark/sql/CollationSQLRegexpSuite.scala | 20 +-
 .../sql/CollationStringExpressionsSuite.scala | 4 +-
 .../org/apache/spark/sql/CollationSuite.scala | 23 -
 .../collation/DefaultCollationTestSuite.scala | 490 ++++++++++++++++++
 26 files changed, 855 insertions(+), 62 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
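Before the patch body, here is a minimal standalone Scala sketch of the reference-equality issue described above and of the two-step workaround. The `StrType`, `DefaultStrType`, `TempStrType`, and `transform` names are illustrative stand-ins, not the real Spark classes:

```scala
// Illustrative sketch only -- these are NOT the real Spark classes.
object DefaultStringTypeSketch extends App {
  class StrType(val collationId: Int) {
    // Value equality: any StrType with the same collationId is "equal".
    override def equals(other: Any): Boolean = other match {
      case s: StrType => s.collationId == collationId
      case _ => false
    }
    override def hashCode: Int = collationId
  }
  object DefaultStrType extends StrType(0)    // plays the role of the `StringType` object
  case class TempStrType() extends StrType(1) // plays the role of `TemporaryStringType`

  // Like tree node transformations: keep the old node if the replacement is equal to it.
  def transform(dt: StrType)(f: StrType => StrType): StrType = {
    val newDt = f(dt)
    if (newDt == dt) dt else newDt
  }

  val utf8Binary = new StrType(0)

  // One pass: the rewrite is silently dropped because StrType(0) == DefaultStrType.
  val onePass = transform(DefaultStrType)(_ => utf8Binary)
  assert(onePass eq DefaultStrType)

  // Two passes via the temporary type: the default object really gets replaced.
  val firstPass = transform(DefaultStrType)(_ => TempStrType())
  val secondPass = transform(firstPass)(_ => utf8Binary)
  assert(!secondPass.eq(DefaultStrType) && secondPass == utf8Binary)
  println("two-pass replacement succeeded")
}
```

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala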
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 71e8517a4164e..c2cb4a7154076 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => SqlApiConf.get.defaultStringType + case Seq(_) => StringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 49d8bf9e001ab..6dcb8a876b7a2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.internal.types -import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** @@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} abstract class AbstractStringType(supportsTrimCollation: Boolean = false) extends AbstractDataType with Serializable { - override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType + override private[sql] def defaultConcreteType: DataType = StringType override private[sql] def simpleString: String = "string" override private[sql] def acceptsType(other: DataType): Boolean = other match { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index fc32248b4baf3..53dfc5e9b2828 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + ArrayType(elementType.transformRecursively(f), containsNull) + } + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 036de22b4189a..12cfed5b58685 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -105,6 +105,13 @@ abstract class DataType extends AbstractDataType { */ private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + /** + * Recursively applies the provided partial function `f` to transform this DataType tree. 
+ */ + private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) f(this) else this + } + final override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala index 1dfb9aaf9e29b..de656c13ca4bf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -89,6 +89,18 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + MapType( + keyType.transformRecursively(f), + valueType.transformRecursively(f), + valueContainsNull) + } + } } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1eb645e37c4aa..b2cf502f8bdc1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * The id of collation for this StringType. */ @Stable -class StringType private (val collationId: Int) extends AtomicType with Serializable { +class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if they are byte @@ -75,7 +75,14 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ */ override def typeName: String = if (isUTF8BinaryCollation) "string" - else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + else s"string collate $collationName" + + override def toString: String = + if (isUTF8BinaryCollation) "StringType" + else s"StringType($collationName)" + + private[sql] def collationName: String = + CollationFactory.fetchCollation(collationId).collationName // Due to backwards compatibility and compatibility with other readers // all string types are serialized in json as regular strings and diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 07f6b50bd4a7a..cc95d8ee94b02 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + return f(this) + } + + val newFields = fields.map { field => + field.copy(dataType = field.dataType.transformRecursively(f)) + } + StructType(newFields) + } } /** diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84b3ca2289f4c..8e1b9da927c9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: + ResolveDefaultStringTypes :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 532e5e0d0a066..cca1d21df3a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} import org.apache.spark.sql.util.SchemaUtils /** @@ -93,13 +93,6 @@ object CollationTypeCoercion { val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad)) stringPadExpr.withNewChildren(Seq(newStr, len, newPad)) - case raiseError: RaiseError => - val newErrorParams = raiseError.errorParms.dataType match { - case MapType(StringType, StringType, _) => raiseError.errorParms - case _ => Cast(raiseError.errorParms, MapType(StringType, StringType)) - } - raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams)) - case framelessOffsetWindow @ (_: Lag | _: Lead) => val Seq(input, offset, default) = framelessOffsetWindow.children val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default)) @@ -219,6 +212,11 @@ object CollationTypeCoercion { */ private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { + // if there are no collated types, we don't need to do anything + return None + } else if (ResolveDefaultStringTypes.needsResolution(expressions)) { + // if any of the string types are still not resolved + // we need to wait for them to be resolved first + return None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala new file mode 100644 index 0000000000000..75958ff3e1177 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.types.{DataType, StringType} + +/** + * Resolves default string types in queries and commands. For queries, the default string type is + * determined by the session's default collation. For DDL, it should be the default type of the + * object being created or altered (table -> schema -> catalog); since object-level collations are + * not implemented yet, we just use UTF8_BINARY for now. + */ +object ResolveDefaultStringTypes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val newPlan = apply0(plan) + if (plan.ne(newPlan)) { + // Due to how tree transformations work and the StringType object being equal to + // StringType("UTF8_BINARY"), we need to transform the plan twice + // to ensure correct results for all occurrences of the default string type. + val finalPlan = apply0(newPlan) + RuleExecutor.forceAdditionalIteration(finalPlan) + finalPlan + } else { + newPlan + } + } + + private def apply0(plan: LogicalPlan): LogicalPlan = { + if (isDDLCommand(plan)) { + transformDDL(plan) + } else { + transformPlan(plan, sessionDefaultStringType) + } + } + + /** + * Returns whether the given `plan` needs to have its + * default string type resolved. + */ + def needsResolution(plan: LogicalPlan): Boolean = { + if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) { + return false + } + + plan.exists(node => needsResolution(node.expressions)) + } + + /** + * Returns whether any of the given `expressions` needs to have its + * default string type resolved. + */ + def needsResolution(expressions: Seq[Expression]): Boolean = { + expressions.exists(needsResolution) + } + + /** + * Returns whether the given `expression` needs to have its + * default string type resolved. + */ + def needsResolution(expression: Expression): Boolean = { + expression.exists(e => transformExpression.isDefinedAt(e)) + } + + private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType + + /** + * Returns the default string type that should be used in a given DDL command (for now always + * UTF8_BINARY).
+ */ + private def stringTypeForDDLCommand(table: LogicalPlan): StringType = + StringType("UTF8_BINARY") + + /** Returns the session default string type */ + private def sessionDefaultStringType: StringType = + StringType(conf.defaultStringType.collationId) + + private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists { + case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + case _ => isCreateOrAlterPlan(plan) + } + + private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { + case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true + case _ => false + } + + private def transformDDL(plan: LogicalPlan): LogicalPlan = { + val newType = stringTypeForDDLCommand(plan) + + plan resolveOperators { + case p if isCreateOrAlterPlan(p) => + transformPlan(p, newType) + + case addCols: AddColumns => + addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType)) + + case replaceCols: ReplaceColumns => + replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType)) + + case alter: AlterColumn + if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => + alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType))) + } + } + + /** + * Transforms the given plan, by transforming all expressions in its operators to use the given + * new type instead of the default string type. + */ + private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { + plan resolveExpressionsUp { expression => + transformExpression + .andThen(_.apply(newType)) + .applyOrElse(expression, identity[Expression]) + } + } + + /** + * Transforms the given expression, by changing all default string types to the given new type. + */ + private def transformExpression: PartialFunction[Expression, StringType => Expression] = { + case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) => + newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) + + case cast: Cast if hasDefaultStringType(cast.dataType) => + newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + + case Literal(value, dt) if hasDefaultStringType(dt) => + newType => Literal(value, replaceDefaultStringType(dt, newType)) + } + + private def hasDefaultStringType(dataType: DataType): Boolean = + dataType.existsRecursively(isDefaultStringType) + + private def isDefaultStringType(dataType: DataType): Boolean = { + dataType match { + case st: StringType => + // should only return true for StringType object and not StringType("UTF8_BINARY") + st.eq(StringType) || st.isInstanceOf[TemporaryStringType] + case _ => false + } + } + + private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { + dataType.transformRecursively { + case currentType: StringType if isDefaultStringType(currentType) => + if (currentType == newType) { + TemporaryStringType() + } else { + newType + } + } + } + + private def replaceColumnTypes( + colTypes: Seq[QualifiedColType], + newType: StringType): Seq[QualifiedColType] = { + colTypes.map { + case colWithDefault if hasDefaultStringType(colWithDefault.dataType) => + val replaced = replaceDefaultStringType(colWithDefault.dataType, newType) + colWithDefault.copy(dataType = replaced) + + case col => col + } + } +} + +case class TemporaryStringType() extends StringType(1) { + override def toString: String = s"TemporaryStringType($collationId)" +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 62f3997491c07..b9e9e49a39647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) { - case table: UnresolvedInlineTable if table.expressionsResolved => + case table: UnresolvedInlineTable if canResolveTable(table) => EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table) } } + + private def canResolveTable(table: UnresolvedInlineTable): Boolean = { + table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala index 5b4d76a2a73ed..3fc4b71c986ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala @@ -415,7 +415,7 @@ abstract class TypeCoercionHelper { if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } c.copy(children = newChildren) case other => other @@ -465,7 +465,7 @@ abstract class TypeCoercionHelper { if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { children.tail.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } } else { children.tail diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5f1b3dc0a01ac..622a0e0aa5bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)) override def left: Expression = errorClass override def right: Expression = errorParms diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c97920619ba4d..2ea53350fea36 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1863,7 +1863,7 @@ trait PadExpressionBuilderBase extends ExpressionBuilder { BinaryPad(funcName, expressions(0), expressions(1), Literal(Array[Byte](0))) } else { createStringPad(expressions(0), - expressions(1), Literal.create(" ", SQLConf.get.defaultStringType)) + expressions(1), Literal(" ")) } } else if (numArgs == 3) { if (expressions(0).dataType == BinaryType && expressions(2).dataType == BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d558689a5c196..3d74e9d314d57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2150,7 +2150,7 @@ class AstBuilder extends DataTypeAstBuilder } val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) - val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + val table = if (canEagerlyEvaluateInlineTable(ctx, unresolvedTable)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable @@ -2158,6 +2158,42 @@ class AstBuilder extends DataTypeAstBuilder table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } + /** + * Determines if the inline table can be eagerly evaluated. + */ + private def canEagerlyEvaluateInlineTable( + ctx: InlineTableContext, + table: UnresolvedInlineTable): Boolean = { + if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + return false + } else if (!ResolveDefaultStringTypes.needsResolution(table.expressions)) { + // if there are no strings to be resolved, we can always evaluate eagerly + return true + } + + val isSessionCollationSet = conf.defaultStringType != StringType + + // if either of these is true, we need to resolve + // the string types first + !isSessionCollationSet && !contextInsideCreate(ctx) + } + + private def contextInsideCreate(ctx: ParserRuleContext): Boolean = { + var currentContext: RuleContext = ctx + + while (currentContext != null) { + if (currentContext.isInstanceOf[CreateTableContext] || + currentContext.isInstanceOf[ReplaceTableContext] || + currentContext.isInstanceOf[CreateViewContext]) { + return true + } + + currentContext = currentContext.parent + } + + false + } + /** * Create an alias (SubqueryAlias) for a join relation. This is practically the same as * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different
*/ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), conf.defaultStringType) + Literal.create(createString(ctx), StringType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index b465e0e11612f..857522728eaff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,6 +459,12 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } +/** + * A trait used for logical plan nodes that create V1 table definitions, + * so that rules from the catalyst module can identify them. + */ +trait V1CreateTablePlan extends LogicalPlan + /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 76d36fab2096a..bdbf698db2e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,7 +22,8 @@ import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors @@ -30,6 +31,27 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils object RuleExecutor { + + /** + * A tag used to explicitly request an additional iteration of the current batch during + * rule execution, even if the query plan remains unchanged. Increment the tag's value + * to force another iteration. + */ + private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration") + + /** + * Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to + * explicitly force another iteration of the current batch during rule execution. + */ + def forceAdditionalIteration(plan: TreeNode[_]): Unit = { + val oldValue = getForceIterationValue(plan) + plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1) + } + + private def getForceIterationValue(plan: TreeNode[_]): Int = { + plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0) + } + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules.
*/ @@ -303,7 +325,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan)) { + if (isFixedPointReached(lastPlan, curPlan)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false @@ -317,4 +339,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = { + oldPlan.fastEquals(newPlan) && + getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 378eca09097f5..e8031580c1165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5580,7 +5580,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { StringType } else { - StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION))) + StringType(getConf(DEFAULT_COLLATION)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index f6d8f2a66e202..7250b6e2b90e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -715,7 +715,7 @@ class DataTypeSuite extends SparkFunSuite { checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE"), expected = true) checkEqualsIgnoreCompatibleCollation( - StringType("UTF8_BINARY"), StringType("UTF8_LCASE"), expected = true) + StringType("UTF8_LCASE"), StringType("UTF8_BINARY"), expected = true) // Complex types. 
checkEqualsIgnoreCompatibleCollation( ArrayType(StringType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 92c74f7bede18..5f1ab089cf3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -54,6 +54,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case _ if ResolveDefaultStringTypes.needsResolution(plan) => + // if there are still unresolved string types in the plan + // we should not try to resolve it + plan + case AddColumns(ResolvedV1TableIdentifier(ident), cols) => cols.foreach { c => if (c.name.length > 1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index d9367d92d462e..eb9d5813cff7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan { + query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 6feb4587b816f..cf494fcd87451 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -376,6 +376,10 @@ class CollationSQLExpressionsSuite StructField("B", DoubleType, nullable = true) )), CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "", + Row("Spark"), Seq( + StructField("a", StringType, nullable = true) + )), + CsvToStructsTestCase("\"Spark\"", "UTF8_BINARY", "'a STRING COLLATE UNICODE'", "", Row("Spark"), Seq( StructField("a", StringType("UNICODE"), nullable = true) )), @@ -1291,6 +1295,10 @@ class CollationSQLExpressionsSuite StructField("B", DoubleType, nullable = true) )), XmlToStructsTestCase("
<p><s>Spark</s></p>", "UNICODE", "'s STRING'", "", + Row("Spark"), Seq( + StructField("s", StringType, nullable = true) + )), + XmlToStructsTestCase("<p><s>Spark</s></p>
", "UTF8_BINARY", "'s STRING COLLATE UNICODE'", "", Row("Spark"), Seq( StructField("s", StringType("UNICODE"), nullable = true) )), @@ -1515,8 +1523,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType), VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_LCASE", null, IntegerType), - VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", StringType("UNICODE")), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", + StringType), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string collate unicode", "UTF8_BINARY", "2", + StringType("UNICODE")), VariantGetTestCase("[1, \"2\"]", "$[2]", "string", "UNICODE_CI", null, + StringType), + VariantGetTestCase("[1, \"2\"]", "$[2]", "string collate unicode_CI", "UTF8_BINARY", null, StringType("UNICODE_CI")) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 5bb8511d0d935..7cafb999ffcf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -111,15 +111,17 @@ class CollationSQLRegexpSuite } val tableNameLcase = "T_LCASE" withTable(tableNameLcase) { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - sql(s"CREATE TABLE IF NOT EXISTS $tableNameLcase(c STRING) using PARQUET") - sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") - checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) - } + sql(s""" + |CREATE TABLE IF NOT EXISTS $tableNameLcase( + | c STRING COLLATE UTF8_LCASE + |) using PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") + checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 2a0b84c075079..626bd0b239366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -198,8 +198,8 @@ class CollationStringExpressionsSuite checkError( exception = intercept[AnalysisException] { val expr = StringSplitSQL( - Collate(Literal.create("1a2"), "UTF8_BINARY"), - Collate(Literal.create("a"), "UTF8_LCASE")) + Collate(Literal.create("1a2", StringType("UTF8_BINARY")), "UTF8_BINARY"), + Collate(Literal.create("a", StringType("UTF8_BINARY")), "UTF8_LCASE")) CollationTypeCasts.transform(expr) }, condition = "COLLATION_MISMATCH.EXPLICIT", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 1707820053837..f0f81e713457b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1096,29 +1096,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("SPARK-47431: Default collation set to UNICODE, column type test") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - sql(s"CREATE TABLE t(c1 STRING) USING PARQUET") - sql(s"INSERT INTO t VALUES ('a')") - checkAnswer(sql(s"SELECT collation(c1) FROM t"), Seq(Row("UNICODE"))) - } - } - } - - test("SPARK-47431: Create table with UTF8_BINARY, make sure collation persists on read") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY") { - sql("CREATE TABLE t(c1 STRING) USING PARQUET") - sql("INSERT INTO t VALUES ('a')") - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - } - } - test("Create dataframe with non utf8 binary collation") { val schema = StructType(Seq(StructField("Name", StringType("UNICODE_CI")))) val data = Seq(Row("Alice"), Row("Bob"), Row("bob")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala new file mode 100644 index 0000000000000..0de638d4e9bf9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.collation + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.connector.DatasourceV2SQLBase +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSession { + + def dataSource: String = "parquet" + def testTable: String = "test_tbl" + def testView: String = "test_view" + + def withSessionCollationAndTable(collation: String, testTables: String*)(f: => Unit): Unit = { + withTable(testTables: _*) { + withSessionCollation(collation) { + f + } + } + } + + def withSessionCollationAndView(collation: String, viewNames: String*)(f: => Unit): Unit = { + withView(viewNames: _*) { + withSessionCollation(collation) { + f + } + } + } + + def withSessionCollation(collation: String)(f: => Unit): Unit = { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { + f + } + } + + def assertTableColumnCollation( + table: String, + column: String, + expectedCollation: String): Unit = { + val colType = spark.table(table).schema(column).dataType + assert(colType === StringType(expectedCollation)) + } + + def assertThrowsImplicitMismatch(f: => DataFrame): Unit = { + val exception = intercept[AnalysisException] { + f + } + assert(exception.getCondition === "COLLATION_MISMATCH.IMPLICIT") + } + + // region DDL tests + + test("create/alter table") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + // create table with implicit collation + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + + // alter table add column with implicit collation + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c2", "UNICODE") + + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + } + } + + test("create table with explicit collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") + } + + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UNICODE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UNICODE") + } + } + + test("create table as select") { + // literals in select do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS SELECT + | 'a' AS c1, + | 'a' || 'a' AS c2, + | SUBSTRING('a', 1, 1) AS c3, + | SUBSTRING(SUBSTRING('ab', 1, 1), 1, 1) AS c4, + | 'a' = 'A' AS truthy + |""".stripMargin) + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c3", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c4", "UTF8_BINARY") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE truthy"), Seq(Row(0))) + } + + // literals in inline table do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT c1, c1 = 'A' as c2 
FROM VALUES ('a'), ('A') AS vals(c1) + |""".stripMargin) + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(1))) + } + + // cast in select does not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable USING $dataSource AS SELECT cast('a' AS STRING) AS c1") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + } + } + + test("ctas with complex types") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT + | struct('a') AS c1, + | map('a', 'b') AS c2, + | array('a') AS c3 + |""".stripMargin) + + checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c2['a']) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + } + } + + test("ctas with union") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false))) + } + + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION ALL + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false), Row(false))) + } + } + + test("add column") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") + + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + + sql(s"ALTER TABLE $testTable ADD COLUMN c3 STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c3", "UNICODE") + } + } + + test("inline table in CTAS") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource + |AS SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |AS inline_table(c1, c2); + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } + + test("subsequent analyzer iterations correctly resolve default string types") { + // since concat coercion happens after resolving default types this test + // makes sure that we are correctly resolving the default string types + // in subsequent analyzer iterations + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT CONCAT(X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("helloworld"))) + } + + // ELT is similar + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT ELT(1, X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("hello"))) + } + } + + // endregion + + // region DML tests + + test("literals with default collation") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { + + // literal without collation + checkAnswer(sql("SELECT COLLATION('a')"), 
Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(map('a', 'b')['a'])"), Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(array('a')[0])"), Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(struct('a' as c)['c'])"), Seq(Row(sessionCollation))) + } + } + + test("literals with explicit collation") { + withSessionCollation("UTF8_LCASE") { + checkAnswer(sql("SELECT COLLATION('a' collate unicode)"), Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(map('a', 'b' collate unicode)['a'])"), + Seq(Row("UNICODE"))) + + checkAnswer(sql("SELECT COLLATION(array('a' collate unicode)[0])"), Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(struct('a' collate unicode as c)['c'])"), + Seq(Row("UNICODE"))) + } + } + + test("cast is aware of session collation") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { + checkAnswer(sql("SELECT COLLATION(cast('a' as STRING))"), Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(map('a', 'b') as MAP)['a'])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(map_keys(cast(map('a', 'b') as MAP))[0])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(array('a') as ARRAY)[0])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(struct('a' as c) as STRUCT)['c'])"), + Seq(Row(sessionCollation))) + } + } + + test("expressions in where are aware of session collation") { + withSessionCollation("UTF8_LCASE") { + // expression in where is aware of session collation + checkAnswer(sql("SELECT 1 WHERE 'a' = 'A'"), Seq(Row(1))) + + checkAnswer(sql("SELECT 1 WHERE 'a' = cast('A' as STRING)"), Seq(Row(1))) + } + } + + test("having group by is aware of session collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") + + // having clause uses session (default) collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING 'a' = 'A'"), + Seq(Row(1), Row(1))) + + // having clause uses column (implicit) collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING c1 = 'A'"), + Seq(Row(1))) + } + } + + test("min/max are aware of session collation") { + // scalastyle:off nonascii + withSessionCollationAndTable("UNICODE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('1'), ('½')") + + checkAnswer(sql(s"SELECT MIN(c1) FROM $testTable"), Seq(Row("1"))) + + checkAnswer(sql(s"SELECT MAX(c1) FROM $testTable"), Seq(Row("½"))) + } + // scalastyle:on nonascii + } + + test("union operation with subqueries") { + withSessionCollation("UTF8_LCASE") { + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true))) + + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION ALL + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true), Row(true))) + } + } + + test("inline table in SELECT") { + withSessionCollation("UTF8_LCASE") { + val df = s""" + |SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |""".stripMargin + + checkAnswer(sql(df), Seq(Row("a", true))) + } + } + + test("inline table in insert") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") + + sql(s"INSERT INTO $testTable VALUES ('a', 'a' = 'A')") + checkAnswer(sql(s"SELECT * 
FROM $testTable"), Seq(Row("a", true))) + } + } + + test("literals in insert inherit session level collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 BOOLEAN) USING $dataSource") + + sql(s"INSERT INTO $testTable VALUES ('a' = 'A')") + sql(s"INSERT INTO $testTable VALUES (array_contains(array('a'), 'A'))") + sql(s"INSERT INTO $testTable VALUES (CONCAT(X'68656C6C6F', 'world') = 'HELLOWORLD')") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(3))) + } + } + + // endregion +} + +class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite { + + test("create/alter view created from a table") { + val sessionCollation = "UTF8_LCASE" + withSessionCollationAndTable(sessionCollation, testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a', 'a'), ('A', 'A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $testTable") + + assertTableColumnCollation(testView, "c1", "UTF8_BINARY") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UTF8_BINARY", sessionCollation)) + + // filter should use session collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE 'a' = 'A'"), Row(2)) + + // filter should use column collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(1)) + + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = substring('A', 0, 1)"), + Row(1)) + + // literal with explicit collation wins + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A' collate UNICODE_CI"), + Row(2)) + + // two implicit collations -> errors out + assertThrowsImplicitMismatch(sql(s"SELECT c1 = c2 FROM $testView")) + + sql(s"ALTER VIEW $testView AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $testTable") + assertTableColumnCollation(testView, "c1", "UNICODE_CI") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UNICODE_CI", sessionCollation)) + + // after alter both rows should be returned + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(2)) + } + } + } + + test("join view with table") { + val viewTableName = "view_table" + val joinTableName = "join_table" + val sessionCollation = "sr" + + withSessionCollationAndTable(sessionCollation, viewTableName, joinTableName) { + sql(s"CREATE TABLE $viewTableName (c1 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"CREATE TABLE $joinTableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + sql(s"INSERT INTO $viewTableName VALUES ('a')") + sql(s"INSERT INTO $joinTableName VALUES ('A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $viewTableName") + + assertThrowsImplicitMismatch( + sql(s"SELECT * FROM $testView JOIN $joinTableName ON $testView.c1 = $joinTableName.c1")) + + checkAnswer( + sql(s""" + |SELECT COLLATION($testView.c1), COLLATION($joinTableName.c1) + |FROM $testView JOIN $joinTableName + |ON $testView.c1 = $joinTableName.c1 COLLATE UNICODE_CI + |""".stripMargin), + Row("UNICODE_CI", "UTF8_LCASE")) + } + } + } +} + +class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with DatasourceV2SQLBase { + override def testTable: String = s"testcat.${super.testTable}" + override def testView: String = s"testcat.${super.testView}" + + // delete 
only works on v2 + test("delete behavior") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") + + sql(s"DELETE FROM $testTable WHERE 'a' = 'A'") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable"), Seq(Row(0))) + } + } + + test("inline table in RTAS") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") + sql(s""" + |REPLACE TABLE $testTable + |USING $dataSource + |AS SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |AS inline_table(c1, c2); + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } +}
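A note on the `RuleExecutor` change above: the fixed-point loop now keeps iterating while a rule has bumped the force-iteration counter, even if the plan itself compares equal. The following standalone sketch shows the shape of that termination check; the `Plan`, `rule`, and loop definitions are toy stand-ins, not the real executor:

```scala
// Illustrative sketch only -- not the real RuleExecutor.
object FixedPointSketch extends App {
  // A toy "plan": a value plus a counter standing in for the FORCE_ADDITIONAL_ITERATION tag.
  case class Plan(value: Int, forceIterations: Int = 0)

  // Mirrors the intent of forceAdditionalIteration: bump the counter on the new plan.
  def forceAdditionalIteration(plan: Plan): Plan =
    plan.copy(forceIterations = plan.forceIterations + 1)

  // The patched termination check: stop only if the plan is unchanged AND
  // no rule asked for another iteration.
  def isFixedPointReached(oldPlan: Plan, newPlan: Plan): Boolean =
    oldPlan.value == newPlan.value &&
      newPlan.forceIterations <= oldPlan.forceIterations

  // A rule that leaves the plan "equal" but requests one more pass on its first run.
  def rule(plan: Plan, firstPass: Boolean): Plan =
    if (firstPass) forceAdditionalIteration(plan) else plan

  var current = Plan(42)
  var iteration = 0
  var continue = true
  while (continue) {
    iteration += 1
    val next = rule(current, firstPass = iteration == 1)
    if (isFixedPointReached(current, next)) continue = false
    current = next
  }
  println(s"stopped after $iteration iterations") // prints 2: the forced extra pass ran
}
```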