diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 71e8517a4164e..c2cb4a7154076 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => SqlApiConf.get.defaultStringType + case Seq(_) => StringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 49d8bf9e001ab..6dcb8a876b7a2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.internal.types -import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** @@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} abstract class AbstractStringType(supportsTrimCollation: Boolean = false) extends AbstractDataType with Serializable { - override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType + override private[sql] def defaultConcreteType: DataType = StringType override private[sql] def simpleString: String = "string" override private[sql] def acceptsType(other: DataType): Boolean = other match { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index fc32248b4baf3..53dfc5e9b2828 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + ArrayType(elementType.transformRecursively(f), containsNull) + } + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 036de22b4189a..12cfed5b58685 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -105,6 +105,13 @@ abstract class DataType extends AbstractDataType { */ private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + /** + * Recursively applies the provided partial function `f` to transform this DataType tree. 
+ */ + private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) f(this) else this + } + final override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala index 1dfb9aaf9e29b..de656c13ca4bf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -89,6 +89,18 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + f(this) + } else { + MapType( + keyType.transformRecursively(f), + valueType.transformRecursively(f), + valueContainsNull) + } + } } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1eb645e37c4aa..b2cf502f8bdc1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * The id of collation for this StringType. */ @Stable -class StringType private (val collationId: Int) extends AtomicType with Serializable { +class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if they are byte @@ -75,7 +75,14 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ */ override def typeName: String = if (isUTF8BinaryCollation) "string" - else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + else s"string collate $collationName" + + override def toString: String = + if (isUTF8BinaryCollation) "StringType" + else s"StringType($collationName)" + + private[sql] def collationName: String = + CollationFactory.fetchCollation(collationId).collationName // Due to backwards compatibility and compatibility with other readers // all string types are serialized in json as regular strings and diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 07f6b50bd4a7a..cc95d8ee94b02 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } + + override private[spark] def transformRecursively( + f: PartialFunction[DataType, DataType]): DataType = { + if (f.isDefinedAt(this)) { + return f(this) + } + + val newFields = fields.map { field => + field.copy(dataType = field.dataType.transformRecursively(f)) + } + StructType(newFields) + } } /** diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84b3ca2289f4c..8e1b9da927c9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveAliases :: ResolveSubquery :: ResolveSubqueryColumnAliases :: + ResolveDefaultStringTypes :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala index 532e5e0d0a066..cca1d21df3a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} import org.apache.spark.sql.util.SchemaUtils /** @@ -93,13 +93,6 @@ object CollationTypeCoercion { val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad)) stringPadExpr.withNewChildren(Seq(newStr, len, newPad)) - case raiseError: RaiseError => - val newErrorParams = raiseError.errorParms.dataType match { - case MapType(StringType, StringType, _) => raiseError.errorParms - case _ => Cast(raiseError.errorParms, MapType(StringType, StringType)) - } - raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams)) - case framelessOffsetWindow @ (_: Lag | _: Lead) => val Seq(input, offset, default) = framelessOffsetWindow.children val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default)) @@ -219,6 +212,11 @@ object CollationTypeCoercion { */ private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = { if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) { + // if there are no collated types we don't need to do anything + return None + } else if (ResolveDefaultStringTypes.needsResolution(expressions)) { + // if any of the strings types are still not resolved + // we need to wait for them to be resolved first return None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala new file mode 100644 index 0000000000000..75958ff3e1177 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan} +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.types.{DataType, StringType} + +/** + * Resolves default string types in queries and commands. For queries, the default string type is + * determined by the session's default string type. For DDL, the default string type is the + * default type of the object (table -> schema -> catalog). However, this is not implemented yet. + * So, we will just use UTF8_BINARY for now. + */ +object ResolveDefaultStringTypes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val newPlan = apply0(plan) + if (plan.ne(newPlan)) { + // Due to how tree transformations work and StringType object being equal to + // StringType("UTF8_BINARY"), we need to transform the plan twice + // to ensure the correct results for occurrences of default string type. + val finalPlan = apply0(newPlan) + RuleExecutor.forceAdditionalIteration(finalPlan) + finalPlan + } else { + newPlan + } + } + + private def apply0(plan: LogicalPlan): LogicalPlan = { + if (isDDLCommand(plan)) { + transformDDL(plan) + } else { + transformPlan(plan, sessionDefaultStringType) + } + } + + /** + * Returns whether any of the given `plan` needs to have its + * default string type resolved. + */ + def needsResolution(plan: LogicalPlan): Boolean = { + if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) { + return false + } + + plan.exists(node => needsResolution(node.expressions)) + } + + /** + * Returns whether any of the given `expressions` needs to have its + * default string type resolved. + */ + def needsResolution(expressions: Seq[Expression]): Boolean = { + expressions.exists(needsResolution) + } + + /** + * Returns whether the given `expression` needs to have its + * default string type resolved. + */ + def needsResolution(expression: Expression): Boolean = { + expression.exists(e => transformExpression.isDefinedAt(e)) + } + + private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType + + /** + * Returns the default string type that should be used in a given DDL command (for now always + * UTF8_BINARY). 
+ */ + private def stringTypeForDDLCommand(table: LogicalPlan): StringType = + StringType("UTF8_BINARY") + + /** Returns the session default string type */ + private def sessionDefaultStringType: StringType = + StringType(conf.defaultStringType.collationId) + + private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists { + case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true + case _ => isCreateOrAlterPlan(plan) + } + + private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match { + case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true + case _ => false + } + + private def transformDDL(plan: LogicalPlan): LogicalPlan = { + val newType = stringTypeForDDLCommand(plan) + + plan resolveOperators { + case p if isCreateOrAlterPlan(p) => + transformPlan(p, newType) + + case addCols: AddColumns => + addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType)) + + case replaceCols: ReplaceColumns => + replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType)) + + case alter: AlterColumn + if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) => + alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType))) + } + } + + /** + * Transforms the given plan, by transforming all expressions in its operators to use the given + * new type instead of the default string type. + */ + private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = { + plan resolveExpressionsUp { expression => + transformExpression + .andThen(_.apply(newType)) + .applyOrElse(expression, identity[Expression]) + } + } + + /** + * Transforms the given expression, by changing all default string types to the given new type. + */ + private def transformExpression: PartialFunction[Expression, StringType => Expression] = { + case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) => + newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType)) + + case cast: Cast if hasDefaultStringType(cast.dataType) => + newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType)) + + case Literal(value, dt) if hasDefaultStringType(dt) => + newType => Literal(value, replaceDefaultStringType(dt, newType)) + } + + private def hasDefaultStringType(dataType: DataType): Boolean = + dataType.existsRecursively(isDefaultStringType) + + private def isDefaultStringType(dataType: DataType): Boolean = { + dataType match { + case st: StringType => + // should only return true for StringType object and not StringType("UTF8_BINARY") + st.eq(StringType) || st.isInstanceOf[TemporaryStringType] + case _ => false + } + } + + private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = { + dataType.transformRecursively { + case currentType: StringType if isDefaultStringType(currentType) => + if (currentType == newType) { + TemporaryStringType() + } else { + newType + } + } + } + + private def replaceColumnTypes( + colTypes: Seq[QualifiedColType], + newType: StringType): Seq[QualifiedColType] = { + colTypes.map { + case colWithDefault if hasDefaultStringType(colWithDefault.dataType) => + val replaced = replaceDefaultStringType(colWithDefault.dataType, newType) + colWithDefault.copy(dataType = replaced) + + case col => col + } + } +} + +case class TemporaryStringType() extends StringType(1) { + override def toString: String = s"TemporaryStringType($collationId)" +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 62f3997491c07..b9e9e49a39647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) { - case table: UnresolvedInlineTable if table.expressionsResolved => + case table: UnresolvedInlineTable if canResolveTable(table) => EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table) } } + + private def canResolveTable(table: UnresolvedInlineTable): Boolean = { + table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala index 5b4d76a2a73ed..3fc4b71c986ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala @@ -415,7 +415,7 @@ abstract class TypeCoercionHelper { if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } c.copy(children = newChildren) case other => other @@ -465,7 +465,7 @@ abstract class TypeCoercionHelper { if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { children.tail.map { e => - implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) + implicitCast(e, StringType).getOrElse(e) } } else { children.tail diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5f1b3dc0a01ac..622a0e0aa5bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)) override def left: Expression = errorClass override def right: Expression = errorParms diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c97920619ba4d..2ea53350fea36 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1863,7 +1863,7 @@ trait PadExpressionBuilderBase extends ExpressionBuilder { BinaryPad(funcName, expressions(0), expressions(1), Literal(Array[Byte](0))) } else { createStringPad(expressions(0), - expressions(1), Literal.create(" ", SQLConf.get.defaultStringType)) + expressions(1), Literal(" ")) } } else if (numArgs == 3) { if (expressions(0).dataType == BinaryType && expressions(2).dataType == BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d558689a5c196..3d74e9d314d57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2150,7 +2150,7 @@ class AstBuilder extends DataTypeAstBuilder } val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq) - val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + val table = if (canEagerlyEvaluateInlineTable(ctx, unresolvedTable)) { EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) } else { unresolvedTable @@ -2158,6 +2158,42 @@ class AstBuilder extends DataTypeAstBuilder table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } + /** + * Determines if the inline table can be eagerly evaluated. + */ + private def canEagerlyEvaluateInlineTable( + ctx: InlineTableContext, + table: UnresolvedInlineTable): Boolean = { + if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + return false + } else if (!ResolveDefaultStringTypes.needsResolution(table.expressions)) { + // if there are no strings to be resolved we can always evaluate eagerly + return true + } + + val isSessionCollationSet = conf.defaultStringType != StringType + + // if either of these are true we need to resolve + // the string types first + !isSessionCollationSet && !contextInsideCreate(ctx) + } + + private def contextInsideCreate(ctx: ParserRuleContext): Boolean = { + var currentContext: RuleContext = ctx + + while (currentContext != null) { + if (currentContext.isInstanceOf[CreateTableContext] || + currentContext.isInstanceOf[ReplaceTableContext] || + currentContext.isInstanceOf[CreateViewContext]) { + return true + } + + currentContext = currentContext.parent + } + + false + } + /** * Create an alias (SubqueryAlias) for a join relation. This is practically the same as * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different @@ -3369,7 +3405,7 @@ class AstBuilder extends DataTypeAstBuilder * Create a String literal expression. 
*/ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal.create(createString(ctx), conf.defaultStringType) + Literal.create(createString(ctx), StringType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index b465e0e11612f..857522728eaff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -459,6 +459,12 @@ trait V2CreateTableAsSelectPlan newQuery: LogicalPlan): V2CreateTableAsSelectPlan } +/** + * A trait used for logical plan nodes that create V1 table definitions, + * and so that rules from the catalyst module can identify them. + */ +trait V1CreateTablePlan extends LogicalPlan + /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { def name: LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 76d36fab2096a..bdbf698db2e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,7 +22,8 @@ import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.errors.QueryExecutionErrors @@ -30,6 +31,27 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils object RuleExecutor { + + /** + * A tag used to explicitly request an additional iteration of the current batch during + * rule execution, even if the query plan remains unchanged. Increment the tag's value + * to enforce another iteration. + */ + private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration") + + /** + * Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to + * explicitly force another iteration of the current batch during rule execution. + */ + def forceAdditionalIteration(plan: TreeNode[_]): Unit = { + val oldValue = getForceIterationValue(plan) + plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1) + } + + private def getForceIterationValue(plan: TreeNode[_]): Int = { + plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0) + } + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. 
*/ @@ -303,7 +325,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { continue = false } - if (curPlan.fastEquals(lastPlan)) { + if (isFixedPointReached(lastPlan, curPlan)) { logTrace( s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.") continue = false @@ -317,4 +339,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = { + oldPlan.fastEquals(newPlan) && + getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 378eca09097f5..e8031580c1165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5580,7 +5580,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { StringType } else { - StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION))) + StringType(getConf(DEFAULT_COLLATION)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index f6d8f2a66e202..7250b6e2b90e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -715,7 +715,7 @@ class DataTypeSuite extends SparkFunSuite { checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE"), expected = true) checkEqualsIgnoreCompatibleCollation( - StringType("UTF8_BINARY"), StringType("UTF8_LCASE"), expected = true) + StringType("UTF8_LCASE"), StringType("UTF8_BINARY"), expected = true) // Complex types. 
checkEqualsIgnoreCompatibleCollation( ArrayType(StringType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 92c74f7bede18..5f1ab089cf3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -54,6 +54,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case _ if ResolveDefaultStringTypes.needsResolution(plan) => + // if there are still unresolved string types in the plan + // we should not try to resolve it + plan + case AddColumns(ResolvedV1TableIdentifier(ident), cols) => cols.foreach { c => if (c.name.length > 1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index d9367d92d462e..eb9d5813cff7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ case class CreateTable( tableDesc: CatalogTable, mode: SaveMode, - query: Option[LogicalPlan]) extends LogicalPlan { + query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 6feb4587b816f..cf494fcd87451 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -376,6 +376,10 @@ class CollationSQLExpressionsSuite StructField("B", DoubleType, nullable = true) )), CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "", + Row("Spark"), Seq( + StructField("a", StringType, nullable = true) + )), + CsvToStructsTestCase("\"Spark\"", "UTF8_BINARY", "'a STRING COLLATE UNICODE'", "", Row("Spark"), Seq( StructField("a", StringType("UNICODE"), nullable = true) )), @@ -1291,6 +1295,10 @@ class CollationSQLExpressionsSuite StructField("B", DoubleType, nullable = true) )), XmlToStructsTestCase("
<p>\n  <s>Spark</s>\n</p>\n", "UNICODE", "'s STRING'", "", + Row("Spark"), Seq( + StructField("s", StringType, nullable = true) + )), + XmlToStructsTestCase("<p>\n  <s>Spark</s>\n</p>\n
", "UTF8_BINARY", "'s STRING COLLATE UNICODE'", "", Row("Spark"), Seq( StructField("s", StringType("UNICODE"), nullable = true) )), @@ -1515,8 +1523,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType), VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_LCASE", null, IntegerType), - VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", StringType("UNICODE")), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", + StringType), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string collate unicode", "UTF8_BINARY", "2", + StringType("UNICODE")), VariantGetTestCase("[1, \"2\"]", "$[2]", "string", "UNICODE_CI", null, + StringType), + VariantGetTestCase("[1, \"2\"]", "$[2]", "string collate unicode_CI", "UTF8_BINARY", null, StringType("UNICODE_CI")) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 5bb8511d0d935..7cafb999ffcf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -111,15 +111,17 @@ class CollationSQLRegexpSuite } val tableNameLcase = "T_LCASE" withTable(tableNameLcase) { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - sql(s"CREATE TABLE IF NOT EXISTS $tableNameLcase(c STRING) using PARQUET") - sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") - checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) - checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) - } + sql(s""" + |CREATE TABLE IF NOT EXISTS $tableNameLcase( + | c STRING COLLATE UTF8_LCASE + |) using PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')") + checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true)) + checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 2a0b84c075079..626bd0b239366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -198,8 +198,8 @@ class CollationStringExpressionsSuite checkError( exception = intercept[AnalysisException] { val expr = StringSplitSQL( - Collate(Literal.create("1a2"), "UTF8_BINARY"), - Collate(Literal.create("a"), "UTF8_LCASE")) + Collate(Literal.create("1a2", StringType("UTF8_BINARY")), "UTF8_BINARY"), + Collate(Literal.create("a", StringType("UTF8_BINARY")), "UTF8_LCASE")) CollationTypeCasts.transform(expr) }, condition = "COLLATION_MISMATCH.EXPLICIT", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 1707820053837..f0f81e713457b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1096,29 +1096,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("SPARK-47431: Default collation set to UNICODE, column type test") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - sql(s"CREATE TABLE t(c1 STRING) USING PARQUET") - sql(s"INSERT INTO t VALUES ('a')") - checkAnswer(sql(s"SELECT collation(c1) FROM t"), Seq(Row("UNICODE"))) - } - } - } - - test("SPARK-47431: Create table with UTF8_BINARY, make sure collation persists on read") { - withTable("t") { - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY") { - sql("CREATE TABLE t(c1 STRING) USING PARQUET") - sql("INSERT INTO t VALUES ('a')") - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY"))) - } - } - } - test("Create dataframe with non utf8 binary collation") { val schema = StructType(Seq(StructField("Name", StringType("UNICODE_CI")))) val data = Seq(Row("Alice"), Row("Bob"), Row("bob")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala new file mode 100644 index 0000000000000..0de638d4e9bf9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.collation + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.connector.DatasourceV2SQLBase +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSession { + + def dataSource: String = "parquet" + def testTable: String = "test_tbl" + def testView: String = "test_view" + + def withSessionCollationAndTable(collation: String, testTables: String*)(f: => Unit): Unit = { + withTable(testTables: _*) { + withSessionCollation(collation) { + f + } + } + } + + def withSessionCollationAndView(collation: String, viewNames: String*)(f: => Unit): Unit = { + withView(viewNames: _*) { + withSessionCollation(collation) { + f + } + } + } + + def withSessionCollation(collation: String)(f: => Unit): Unit = { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { + f + } + } + + def assertTableColumnCollation( + table: String, + column: String, + expectedCollation: String): Unit = { + val colType = spark.table(table).schema(column).dataType + assert(colType === StringType(expectedCollation)) + } + + def assertThrowsImplicitMismatch(f: => DataFrame): Unit = { + val exception = intercept[AnalysisException] { + f + } + assert(exception.getCondition === "COLLATION_MISMATCH.IMPLICIT") + } + + // region DDL tests + + test("create/alter table") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + // create table with implicit collation + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + + // alter table add column with implicit collation + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c2", "UNICODE") + + sql(s"ALTER TABLE $testTable ALTER COLUMN c2 TYPE STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + } + } + + test("create table with explicit collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") + } + + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UNICODE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UNICODE") + } + } + + test("create table as select") { + // literals in select do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS SELECT + | 'a' AS c1, + | 'a' || 'a' AS c2, + | SUBSTRING('a', 1, 1) AS c3, + | SUBSTRING(SUBSTRING('ab', 1, 1), 1, 1) AS c4, + | 'a' = 'A' AS truthy + |""".stripMargin) + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c3", "UTF8_BINARY") + assertTableColumnCollation(testTable, "c4", "UTF8_BINARY") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE truthy"), Seq(Row(0))) + } + + // literals in inline table do not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT c1, c1 = 'A' as c2 
FROM VALUES ('a'), ('A') AS vals(c1) + |""".stripMargin) + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(1))) + } + + // cast in select does not pick up session collation + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable USING $dataSource AS SELECT cast('a' AS STRING) AS c1") + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + } + } + + test("ctas with complex types") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT + | struct('a') AS c1, + | map('a', 'b') AS c2, + | array('a') AS c3 + |""".stripMargin) + + checkAnswer(sql(s"SELECT COLLATION(c1.col1) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c2['a']) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + checkAnswer(sql(s"SELECT COLLATION(c3[0]) FROM $testTable"), Seq(Row("UTF8_BINARY"))) + } + } + + test("ctas with union") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false))) + } + + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable USING $dataSource AS + |SELECT 'a' = 'A' AS c1 + |UNION ALL + |SELECT 'b' = 'B' AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $testTable"), Seq(Row(false), Row(false))) + } + } + + test("add column") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + assertTableColumnCollation(testTable, "c1", "UTF8_LCASE") + + sql(s"ALTER TABLE $testTable ADD COLUMN c2 STRING") + assertTableColumnCollation(testTable, "c2", "UTF8_BINARY") + + sql(s"ALTER TABLE $testTable ADD COLUMN c3 STRING COLLATE UNICODE") + assertTableColumnCollation(testTable, "c3", "UNICODE") + } + } + + test("inline table in CTAS") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource + |AS SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |AS inline_table(c1, c2); + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } + + test("subsequent analyzer iterations correctly resolve default string types") { + // since concat coercion happens after resolving default types this test + // makes sure that we are correctly resolving the default string types + // in subsequent analyzer iterations + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT CONCAT(X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("helloworld"))) + } + + // ELT is similar + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s""" + |CREATE TABLE $testTable + |USING $dataSource AS + |SELECT ELT(1, X'68656C6C6F', 'world') AS c1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT c1 FROM $testTable"), Seq(Row("hello"))) + } + } + + // endregion + + // region DML tests + + test("literals with default collation") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { + + // literal without collation + checkAnswer(sql("SELECT COLLATION('a')"), 
Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(map('a', 'b')['a'])"), Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(array('a')[0])"), Seq(Row(sessionCollation))) + + checkAnswer(sql("SELECT COLLATION(struct('a' as c)['c'])"), Seq(Row(sessionCollation))) + } + } + + test("literals with explicit collation") { + withSessionCollation("UTF8_LCASE") { + checkAnswer(sql("SELECT COLLATION('a' collate unicode)"), Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(map('a', 'b' collate unicode)['a'])"), + Seq(Row("UNICODE"))) + + checkAnswer(sql("SELECT COLLATION(array('a' collate unicode)[0])"), Seq(Row("UNICODE"))) + + checkAnswer( + sql("SELECT COLLATION(struct('a' collate unicode as c)['c'])"), + Seq(Row("UNICODE"))) + } + } + + test("cast is aware of session collation") { + val sessionCollation = "UTF8_LCASE" + withSessionCollation(sessionCollation) { + checkAnswer(sql("SELECT COLLATION(cast('a' as STRING))"), Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(map('a', 'b') as MAP)['a'])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(map_keys(cast(map('a', 'b') as MAP))[0])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(array('a') as ARRAY)[0])"), + Seq(Row(sessionCollation))) + + checkAnswer( + sql("SELECT COLLATION(cast(struct('a' as c) as STRUCT)['c'])"), + Seq(Row(sessionCollation))) + } + } + + test("expressions in where are aware of session collation") { + withSessionCollation("UTF8_LCASE") { + // expression in where is aware of session collation + checkAnswer(sql("SELECT 1 WHERE 'a' = 'A'"), Seq(Row(1))) + + checkAnswer(sql("SELECT 1 WHERE 'a' = cast('A' as STRING)"), Seq(Row(1))) + } + } + + test("having group by is aware of session collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") + + // having clause uses session (default) collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING 'a' = 'A'"), + Seq(Row(1), Row(1))) + + // having clause uses column (implicit) collation + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testTable GROUP BY c1 HAVING c1 = 'A'"), + Seq(Row(1))) + } + } + + test("min/max are aware of session collation") { + // scalastyle:off nonascii + withSessionCollationAndTable("UNICODE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('1'), ('½')") + + checkAnswer(sql(s"SELECT MIN(c1) FROM $testTable"), Seq(Row("1"))) + + checkAnswer(sql(s"SELECT MAX(c1) FROM $testTable"), Seq(Row("½"))) + } + // scalastyle:on nonascii + } + + test("union operation with subqueries") { + withSessionCollation("UTF8_LCASE") { + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true))) + + checkAnswer( + sql(s""" + |SELECT 'a' = 'A' + |UNION ALL + |SELECT 'b' = 'B' + |""".stripMargin), + Seq(Row(true), Row(true))) + } + } + + test("inline table in SELECT") { + withSessionCollation("UTF8_LCASE") { + val df = s""" + |SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |""".stripMargin + + checkAnswer(sql(df), Seq(Row("a", true))) + } + } + + test("inline table in insert") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") + + sql(s"INSERT INTO $testTable VALUES ('a', 'a' = 'A')") + checkAnswer(sql(s"SELECT * 
FROM $testTable"), Seq(Row("a", true))) + } + } + + test("literals in insert inherit session level collation") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 BOOLEAN) USING $dataSource") + + sql(s"INSERT INTO $testTable VALUES ('a' = 'A')") + sql(s"INSERT INTO $testTable VALUES (array_contains(array('a'), 'A'))") + sql(s"INSERT INTO $testTable VALUES (CONCAT(X'68656C6C6F', 'world') = 'HELLOWORLD')") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c1"), Seq(Row(3))) + } + } + + // endregion +} + +class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite { + + test("create/alter view created from a table") { + val sessionCollation = "UTF8_LCASE" + withSessionCollationAndTable(sessionCollation, testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a', 'a'), ('A', 'A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $testTable") + + assertTableColumnCollation(testView, "c1", "UTF8_BINARY") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UTF8_BINARY", sessionCollation)) + + // filter should use session collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE 'a' = 'A'"), Row(2)) + + // filter should use column collation + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(1)) + + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = substring('A', 0, 1)"), + Row(1)) + + // literal with explicit collation wins + checkAnswer( + sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A' collate UNICODE_CI"), + Row(2)) + + // two implicit collations -> errors out + assertThrowsImplicitMismatch(sql(s"SELECT c1 = c2 FROM $testView")) + + sql(s"ALTER VIEW $testView AS SELECT c1 COLLATE UNICODE_CI AS c1, c2 FROM $testTable") + assertTableColumnCollation(testView, "c1", "UNICODE_CI") + assertTableColumnCollation(testView, "c2", "UNICODE_CI") + checkAnswer( + sql(s"SELECT DISTINCT COLLATION(c1), COLLATION('a') FROM $testView"), + Row("UNICODE_CI", sessionCollation)) + + // after alter both rows should be returned + checkAnswer(sql(s"SELECT COUNT(*) FROM $testView WHERE c1 = 'A'"), Row(2)) + } + } + } + + test("join view with table") { + val viewTableName = "view_table" + val joinTableName = "join_table" + val sessionCollation = "sr" + + withSessionCollationAndTable(sessionCollation, viewTableName, joinTableName) { + sql(s"CREATE TABLE $viewTableName (c1 STRING COLLATE UNICODE_CI) USING $dataSource") + sql(s"CREATE TABLE $joinTableName (c1 STRING COLLATE UTF8_LCASE) USING $dataSource") + sql(s"INSERT INTO $viewTableName VALUES ('a')") + sql(s"INSERT INTO $joinTableName VALUES ('A')") + + withView(testView) { + sql(s"CREATE VIEW $testView AS SELECT * FROM $viewTableName") + + assertThrowsImplicitMismatch( + sql(s"SELECT * FROM $testView JOIN $joinTableName ON $testView.c1 = $joinTableName.c1")) + + checkAnswer( + sql(s""" + |SELECT COLLATION($testView.c1), COLLATION($joinTableName.c1) + |FROM $testView JOIN $joinTableName + |ON $testView.c1 = $joinTableName.c1 COLLATE UNICODE_CI + |""".stripMargin), + Row("UNICODE_CI", "UTF8_LCASE")) + } + } + } +} + +class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with DatasourceV2SQLBase { + override def testTable: String = s"testcat.${super.testTable}" + override def testView: String = s"testcat.${super.testView}" + + // delete 
only works on v2 + test("delete behavior") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING) USING $dataSource") + sql(s"INSERT INTO $testTable VALUES ('a'), ('A')") + + sql(s"DELETE FROM $testTable WHERE 'a' = 'A'") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable"), Seq(Row(0))) + } + } + + test("inline table in RTAS") { + withSessionCollationAndTable("UTF8_LCASE", testTable) { + sql(s"CREATE TABLE $testTable (c1 STRING, c2 BOOLEAN) USING $dataSource") + sql(s""" + |REPLACE TABLE $testTable + |USING $dataSource + |AS SELECT * + |FROM (VALUES ('a', 'a' = 'A')) + |AS inline_table(c1, c2); + |""".stripMargin) + + assertTableColumnCollation(testTable, "c1", "UTF8_BINARY") + checkAnswer(sql(s"SELECT COUNT(*) FROM $testTable WHERE c2"), Seq(Row(0))) + } + } +}
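
Illustrative sketch (not part of the patch): the `DataType.transformRecursively` hook added above rewrites a type tree top-down, short-circuiting at the first node the partial function accepts. A minimal use, assuming a caller inside an `org.apache.spark` package since the method is `private[spark]`:

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("tags", ArrayType(StringType)),
  StructField("props", MapType(StringType, IntegerType))))

// Replace every default (UTF8_BINARY) string type with a collated one.
// Note that StringType("UTF8_BINARY") == StringType, so matching by equality
// also catches explicitly declared UTF8_BINARY columns -- exactly the
// ambiguity the rule above sidesteps with reference equality (`eq`) and
// TemporaryStringType.
val collated = schema.transformRecursively {
  case st: StringType if st == StringType => StringType("UTF8_LCASE")
}
// collated: struct<name: string collate UTF8_LCASE,
//                  tags: array<string collate UTF8_LCASE>,
//                  props: map<string collate UTF8_LCASE, int>>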
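
Illustrative sketch (not part of the patch): the RuleExecutor change exists because `fastEquals` cannot see a rewrite whose only effect is swapping equal string types, so a rule may legitimately change the plan without changing its fixed-point signature. A rule can request another pass of its batch as below; `rewriteOnce` is a hypothetical stand-in for the rule's real transformation:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

object ExampleForcingRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    val newPlan = rewriteOnce(plan)
    if (newPlan ne plan) {
      // The new plan may still satisfy fastEquals against the input, so bump
      // its FORCE_ADDITIONAL_ITERATION tag; isFixedPointReached only declares
      // a fixed point once the tag value stops increasing as well.
      RuleExecutor.forceAdditionalIteration(newPlan)
    }
    newPlan
  }

  // Hypothetical placeholder for the rule's actual rewrite.
  private def rewriteOnce(plan: LogicalPlan): LogicalPlan = plan
}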
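
Illustrative sketch of the intended semantics, distilled from DefaultCollationTestSuite above (assumes a running SparkSession `spark`; the string config key is assumed to be the one behind SqlApiConf.DEFAULT_COLLATION):

spark.conf.set("spark.sql.session.collation.default", "UTF8_LCASE")

// Query-level literals, casts and inline tables resolve to the session type:
spark.sql("SELECT COLLATION('a')").show()                              // UTF8_LCASE
spark.sql("SELECT 1 WHERE 'a' = 'A'").show()                           // one row
spark.sql("SELECT * FROM VALUES ('a', 'a' = 'A') AS t(c1, c2)").show() // (a, true)

// DDL keeps UTF8_BINARY as the object default for now, including CTAS output:
spark.sql("CREATE TABLE t (c1 STRING) USING parquet")
// spark.table("t").schema("c1").dataType == StringType  (UTF8_BINARY)
spark.sql("CREATE TABLE u USING parquet AS SELECT 'a' = 'A' AS c2")
// u contains false: literals inside the CTAS query resolved to UTF8_BINARY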