diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index 71e8517a4164e..c2cb4a7154076 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
       case (TIMESTAMP_LTZ, Nil) => TimestampType
       case (STRING, Nil) =>
         typeCtx.children.asScala.toSeq match {
-          case Seq(_) => SqlApiConf.get.defaultStringType
+          case Seq(_) => StringType
           case Seq(_, ctx: CollateClauseContext) =>
             val collationName = visitCollateClause(ctx)
             val collationId = CollationFactory.collationNameToId(collationName)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 49d8bf9e001ab..6dcb8a876b7a2 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.internal.types
 
-import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
@@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
     extends AbstractDataType with Serializable {
-  override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
+  override private[sql] def defaultConcreteType: DataType = StringType
 
   override private[sql] def simpleString: String = "string"
 
   override private[sql] def acceptsType(other: DataType): Boolean = other match {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index fc32248b4baf3..53dfc5e9b2828 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || elementType.existsRecursively(f)
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      f(this)
+    } else {
+      ArrayType(elementType.transformRecursively(f), containsNull)
+    }
+  }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 036de22b4189a..12cfed5b58685 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -105,6 +105,13 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)
 
+  /**
+   * Recursively applies the provided partial function `f` to transform this DataType tree.
+   */
+  private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) f(this) else this
+  }
+
   final override private[sql] def defaultConcreteType: DataType = this
 
   override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 1dfb9aaf9e29b..de656c13ca4bf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -89,6 +89,18 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Bo
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      f(this)
+    } else {
+      MapType(
+        keyType.transformRecursively(f),
+        valueType.transformRecursively(f),
+        valueContainsNull)
+    }
+  }
 }
 
 /**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 1eb645e37c4aa..b2cf502f8bdc1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
  *   The id of collation for this StringType.
  */
 @Stable
-class StringType private (val collationId: Int) extends AtomicType with Serializable {
+class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable {
 
   /**
    * Support for Binary Equality implies that strings are considered equal only if they are byte
@@ -75,7 +75,14 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
    */
   override def typeName: String =
     if (isUTF8BinaryCollation) "string"
-    else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}"
+    else s"string collate $collationName"
+
+  override def toString: String =
+    if (isUTF8BinaryCollation) "StringType"
+    else s"StringType($collationName)"
+
+  private[sql] def collationName: String =
+    CollationFactory.fetchCollation(collationId).collationName
 
   // Due to backwards compatibility and compatibility with other readers
   // all string types are serialized in json as regular strings and
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 07f6b50bd4a7a..cc95d8ee94b02 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || fields.exists(field => field.dataType.existsRecursively(f))
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      return f(this)
+    }
+
+    val newFields = fields.map { field =>
+      field.copy(dataType = field.dataType.transformRecursively(f))
+    }
+    StructType(newFields)
+  }
 }
 
 /**
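Taken together, these overrides give the type tree a structural rewrite hook that mirrors `existsRecursively`. A minimal sketch of the behavior (illustrative only: `transformRecursively` is `private[spark]`, so this compiles only from within Spark's own packages):

```scala
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("tags", ArrayType(StringType)),
  StructField("props", MapType(StringType, IntegerType))))

// Rewrite every string type in the tree; containers are rebuilt around their
// transformed children, while types the partial function does not cover
// (IntegerType here) are returned unchanged.
val collated = schema.transformRecursively {
  case _: StringType => StringType("UTF8_LCASE")
}
// collated: struct<name: string collate UTF8_LCASE,
//                  tags: array<string collate UTF8_LCASE>,
//                  props: map<string collate UTF8_LCASE, int>>
```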
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 84b3ca2289f4c..8e1b9da927c9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveAliases ::
       ResolveSubquery ::
       ResolveSubqueryColumnAliases ::
+      ResolveDefaultStringTypes ::
       ResolveWindowOrder ::
       ResolveWindowFrame ::
       ResolveNaturalAndUsingJoin ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 532e5e0d0a066..cca1d21df3a7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 import org.apache.spark.sql.util.SchemaUtils
 
 /**
@@ -93,13 +93,6 @@ object CollationTypeCoercion {
       val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
       stringPadExpr.withNewChildren(Seq(newStr, len, newPad))
 
-    case raiseError: RaiseError =>
-      val newErrorParams = raiseError.errorParms.dataType match {
-        case MapType(StringType, StringType, _) => raiseError.errorParms
-        case _ => Cast(raiseError.errorParms, MapType(StringType, StringType))
-      }
-      raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams))
-
     case framelessOffsetWindow @ (_: Lag | _: Lead) =>
       val Seq(input, offset, default) = framelessOffsetWindow.children
       val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default))
@@ -219,6 +212,11 @@ object CollationTypeCoercion {
    */
   private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = {
     if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+      // If there are no collated types, we don't need to do anything.
+      return None
+    } else if (ResolveDefaultStringTypes.needsResolution(expressions)) {
+      // If any of the string types are still unresolved,
+      // we need to wait for them to be resolved first.
       return None
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
new file mode 100644
index 0000000000000..75958ff3e1177
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.types.{DataType, StringType}
+
+/**
+ * Resolves default string types in queries and commands. For queries, the default string type
+ * is the session's default string type. For DDL, the default string type is the default type of
+ * the object (table -> schema -> catalog). However, that is not implemented yet, so UTF8_BINARY
+ * is used for now.
+ */
+object ResolveDefaultStringTypes extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val newPlan = apply0(plan)
+    if (plan.ne(newPlan)) {
+      // Due to how tree transformations work and the StringType object being equal to
+      // StringType("UTF8_BINARY"), we need to transform the plan twice
+      // to ensure correct results for occurrences of the default string type.
+      val finalPlan = apply0(newPlan)
+      RuleExecutor.forceAdditionalIteration(finalPlan)
+      finalPlan
+    } else {
+      newPlan
+    }
+  }
+
+  private def apply0(plan: LogicalPlan): LogicalPlan = {
+    if (isDDLCommand(plan)) {
+      transformDDL(plan)
+    } else {
+      transformPlan(plan, sessionDefaultStringType)
+    }
+  }
+
+  /**
+   * Returns whether the given `plan` needs to have its
+   * default string type resolved.
+   */
+  def needsResolution(plan: LogicalPlan): Boolean = {
+    if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) {
+      return false
+    }
+
+    plan.exists(node => needsResolution(node.expressions))
+  }
+
+  /**
+   * Returns whether any of the given `expressions` need to have their
+   * default string type resolved.
+   */
+  def needsResolution(expressions: Seq[Expression]): Boolean = {
+    expressions.exists(needsResolution)
+  }
+
+  /**
+   * Returns whether the given `expression` needs to have its
+   * default string type resolved.
+   */
+  def needsResolution(expression: Expression): Boolean = {
+    expression.exists(e => transformExpression.isDefinedAt(e))
+  }
+
+  private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType
+
+  /**
+   * Returns the default string type that should be used in a given DDL command (for now always
+   * UTF8_BINARY).
+   */
+  private def stringTypeForDDLCommand(table: LogicalPlan): StringType =
+    StringType("UTF8_BINARY")
+
+  /** Returns the session default string type. */
+  private def sessionDefaultStringType: StringType =
+    StringType(conf.defaultStringType.collationId)
+
+  private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
+    case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
+    case _ => isCreateOrAlterPlan(plan)
+  }
+
+  private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match {
+    case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true
+    case _ => false
+  }
+
+  private def transformDDL(plan: LogicalPlan): LogicalPlan = {
+    val newType = stringTypeForDDLCommand(plan)
+
+    plan resolveOperators {
+      case p if isCreateOrAlterPlan(p) =>
+        transformPlan(p, newType)
+
+      case addCols: AddColumns =>
+        addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType))
+
+      case replaceCols: ReplaceColumns =>
+        replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType))
+
+      case alter: AlterColumn
+          if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) =>
+        alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType)))
+    }
+  }
+
+  /**
+   * Transforms the given plan, by transforming all expressions in its operators to use the given
+   * new type instead of the default string type.
+   */
+  private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
+    plan resolveExpressionsUp { expression =>
+      transformExpression
+        .andThen(_.apply(newType))
+        .applyOrElse(expression, identity[Expression])
+    }
+  }
+
+  /**
+   * Transforms the given expression, by changing all default string types to the given new type.
+   */
+  private def transformExpression: PartialFunction[Expression, StringType => Expression] = {
+    case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) =>
+      newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
+
+    case cast: Cast if hasDefaultStringType(cast.dataType) =>
+      newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
+
+    case Literal(value, dt) if hasDefaultStringType(dt) =>
+      newType => Literal(value, replaceDefaultStringType(dt, newType))
+  }
+
+  private def hasDefaultStringType(dataType: DataType): Boolean =
+    dataType.existsRecursively(isDefaultStringType)
+
+  private def isDefaultStringType(dataType: DataType): Boolean = {
+    dataType match {
+      case st: StringType =>
+        // Should only return true for the StringType object, not StringType("UTF8_BINARY").
+        st.eq(StringType) || st.isInstanceOf[TemporaryStringType]
+      case _ => false
+    }
+  }
+
+  private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = {
+    dataType.transformRecursively {
+      case currentType: StringType if isDefaultStringType(currentType) =>
+        if (currentType == newType) {
+          TemporaryStringType()
+        } else {
+          newType
+        }
+    }
+  }
+
+  private def replaceColumnTypes(
+      colTypes: Seq[QualifiedColType],
+      newType: StringType): Seq[QualifiedColType] = {
+    colTypes.map {
+      case colWithDefault if hasDefaultStringType(colWithDefault.dataType) =>
+        val replaced = replaceDefaultStringType(colWithDefault.dataType, newType)
+        colWithDefault.copy(dataType = replaced)
+
+      case col => col
+    }
+  }
+}
+
+case class TemporaryStringType() extends StringType(1) {
+  override def toString: String = s"TemporaryStringType($collationId)"
+}
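The `TemporaryStringType` indirection exists because of an equality quirk: the `StringType` singleton and an explicit `StringType("UTF8_BINARY")` are value-equal (same collation id) but not reference-equal, so replacing one with the other looks like a no-op to the tree framework. A hypothetical snippet illustrating the distinction `isDefaultStringType` relies on (assumes `private[sql]` access, e.g. from a test in `org.apache.spark.sql.types`):

```scala
import org.apache.spark.sql.types.StringType

// Value-equal: both carry the UTF8_BINARY collation id ...
assert(StringType == StringType("UTF8_BINARY"))

// ... but only the singleton means "default type, not yet resolved", which is
// exactly the reference check (st.eq(StringType)) performed above.
assert(!(StringType eq StringType("UTF8_BINARY")))
```

This is also why `apply` runs `apply0` twice and then calls `RuleExecutor.forceAdditionalIteration`: the first pass may only swap value-equal types, a change that `fastEquals` cannot observe.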
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index 62f3997491c07..b9e9e49a39647 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
-      case table: UnresolvedInlineTable if table.expressionsResolved =>
+      case table: UnresolvedInlineTable if canResolveTable(table) =>
         EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table)
     }
   }
+
+  private def canResolveTable(table: UnresolvedInlineTable): Boolean = {
+    table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
index 5b4d76a2a73ed..3fc4b71c986ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
@@ -415,7 +415,7 @@ abstract class TypeCoercionHelper {
         if conf.concatBinaryAsString ||
           !children.map(_.dataType).forall(_ == BinaryType) =>
         val newChildren = c.children.map { e =>
-          implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+          implicitCast(e, StringType).getOrElse(e)
         }
         c.copy(children = newChildren)
       case other => other
@@ -465,7 +465,7 @@ abstract class TypeCoercionHelper {
       if (conf.eltOutputAsString ||
         !children.tail.map(_.dataType).forall(_ == BinaryType)) {
         children.tail.map { e =>
-          implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+          implicitCast(e, StringType).getOrElse(e)
         }
       } else {
         children.tail
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 5f1b3dc0a01ac..622a0e0aa5bb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeWithCollation
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
   override def foldable: Boolean = false
   override def nullable: Boolean = true
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCollation, MapType(StringType, StringType))
+    Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation))
 
   override def left: Expression = errorClass
   override def right: Expression = errorParms
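With `RaiseError` typed against `AbstractMapType`, error-parameter maps with any string collation are accepted directly, which is what makes the `raiseError` special case removed from `CollationTypeCoercion` above unnecessary. An illustrative use (hypothetical session; `spark.sql.session.collation.default` is assumed to be the key behind the `DEFAULT_COLLATION` conf referenced in the `SQLConf` hunk below):

```scala
// Under a non-default session collation, the internal error-parameter map
// becomes map<string collate UNICODE, string collate UNICODE>; AbstractMapType
// now accepts it as-is instead of forcing a Cast back to map<string, string>.
spark.sql("SET spark.sql.session.collation.default = UNICODE")
spark.sql("SELECT raise_error('Something went wrong')")
```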
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c97920619ba4d..2ea53350fea36 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1863,7 +1863,7 @@ trait PadExpressionBuilderBase extends ExpressionBuilder {
         BinaryPad(funcName, expressions(0), expressions(1), Literal(Array[Byte](0)))
       } else {
         createStringPad(expressions(0),
-          expressions(1), Literal.create(" ", SQLConf.get.defaultStringType))
+          expressions(1), Literal(" "))
       }
     } else if (numArgs == 3) {
       if (expressions(0).dataType == BinaryType && expressions(2).dataType == BinaryType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d558689a5c196..3d74e9d314d57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2150,7 +2150,7 @@ class AstBuilder extends DataTypeAstBuilder
     }
 
     val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq)
-    val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
+    val table = if (canEagerlyEvaluateInlineTable(ctx, unresolvedTable)) {
       EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)
     } else {
       unresolvedTable
@@ -2158,6 +2158,42 @@
     table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
   }
 
+  /**
+   * Determines if the inline table can be eagerly evaluated.
+   */
+  private def canEagerlyEvaluateInlineTable(
+      ctx: InlineTableContext,
+      table: UnresolvedInlineTable): Boolean = {
+    if (!conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
+      return false
+    } else if (!ResolveDefaultStringTypes.needsResolution(table.expressions)) {
+      // If there are no string types to be resolved, we can always evaluate eagerly.
+      return true
+    }
+
+    val isSessionCollationSet = conf.defaultStringType != StringType
+
+    // If either a session collation is set or we are inside a CREATE/REPLACE
+    // statement, the string types need to be resolved first.
+    !isSessionCollationSet && !contextInsideCreate(ctx)
+  }
+
+  private def contextInsideCreate(ctx: ParserRuleContext): Boolean = {
+    var currentContext: RuleContext = ctx
+
+    while (currentContext != null) {
+      if (currentContext.isInstanceOf[CreateTableContext] ||
+          currentContext.isInstanceOf[ReplaceTableContext] ||
+          currentContext.isInstanceOf[CreateViewContext]) {
+        return true
+      }
+
+      currentContext = currentContext.parent
+    }
+
+    false
+  }
+
   /**
    * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
    * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different
@@ -3369,7 +3405,7 @@ class AstBuilder extends DataTypeAstBuilder
    * Create a String literal expression.
    */
   override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
-    Literal.create(createString(ctx), conf.defaultStringType)
+    Literal.create(createString(ctx), StringType)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index b465e0e11612f..857522728eaff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -459,6 +459,12 @@ trait V2CreateTableAsSelectPlan
       newQuery: LogicalPlan): V2CreateTableAsSelectPlan
 }
 
+/**
+ * A trait used for logical plan nodes that create V1 table definitions,
+ * so that rules from the catalyst module can identify them.
+ */
+trait V1CreateTablePlan extends LogicalPlan
+
 /** A trait used for logical plan nodes that create or replace V2 table definitions. */
 trait V2CreateTablePlan extends LogicalPlan {
   def name: LogicalPlan
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 76d36fab2096a..bdbf698db2e01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -22,7 +22,8 @@ import org.apache.spark.internal.{Logging, MessageWithContext}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.internal.MDC
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue
+import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -30,6 +31,27 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
 
 object RuleExecutor {
+
+  /**
+   * A tag used to explicitly request an additional iteration of the current batch during
+   * rule execution, even if the query plan remains unchanged. Increment the tag's value
+   * to enforce another iteration.
+   */
+  private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration")
+
+  /**
+   * Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to
+   * explicitly force another iteration of the current batch during rule execution.
+   */
+  def forceAdditionalIteration(plan: TreeNode[_]): Unit = {
+    val oldValue = getForceIterationValue(plan)
+    plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1)
+  }
+
+  private def getForceIterationValue(plan: TreeNode[_]): Int = {
+    plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0)
+  }
+
   protected val queryExecutionMeter = QueryExecutionMetering()
 
   /** Dump statistics about time spent running specific rules.
    */
@@ -303,7 +325,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
           continue = false
         }
 
-        if (curPlan.fastEquals(lastPlan)) {
+        if (isFixedPointReached(lastPlan, curPlan)) {
           logTrace(
             s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
           continue = false
@@ -317,4 +339,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
 
     curPlan
   }
+
+  private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = {
+    oldPlan.fastEquals(newPlan) &&
+      getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan)
+  }
 }
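A short sketch of the contract these `RuleExecutor` changes establish (the rule below is hypothetical; only `forceAdditionalIteration` and the tag comparison come from this patch): a rule whose output is value-equal to its input can still demand one more pass of its fixed-point batch.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

object ForceOnePassRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    val newPlan = rewrite(plan)
    if (newPlan ne plan) {
      // The rewrite may only have swapped nodes for value-equal ones (e.g.
      // StringType vs. StringType("UTF8_BINARY")), in which case fastEquals
      // would report a fixed point. Bumping the tag makes isFixedPointReached
      // return false for one more iteration.
      RuleExecutor.forceAdditionalIteration(newPlan)
    }
    newPlan
  }

  // Hypothetical placeholder for an actual transformation.
  private def rewrite(plan: LogicalPlan): LogicalPlan = plan
}
```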
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 378eca09097f5..e8031580c1165 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5580,7 +5580,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
     if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") {
       StringType
     } else {
-      StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION)))
+      StringType(getConf(DEFAULT_COLLATION))
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index f6d8f2a66e202..7250b6e2b90e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -715,7 +715,7 @@ class DataTypeSuite extends SparkFunSuite {
     checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE"), expected = true)
     checkEqualsIgnoreCompatibleCollation(
-      StringType("UTF8_BINARY"), StringType("UTF8_LCASE"), expected = true)
+      StringType("UTF8_LCASE"), StringType("UTF8_BINARY"), expected = true)
     // Complex types.
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(StringType),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 92c74f7bede18..5f1ab089cf3e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -54,6 +54,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
   import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case _ if ResolveDefaultStringTypes.needsResolution(plan) =>
+      // If there are still unresolved default string types in the plan,
+      // we should not try to resolve it yet.
+      plan
+
     case AddColumns(ResolvedV1TableIdentifier(ident), cols) =>
       cols.foreach { c =>
         if (c.name.length > 1) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index d9367d92d462e..eb9d5813cff7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1CreateTablePlan}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types._
 case class CreateTable(
     tableDesc: CatalogTable,
     mode: SaveMode,
-    query: Option[LogicalPlan]) extends LogicalPlan {
+    query: Option[LogicalPlan]) extends LogicalPlan with V1CreateTablePlan {
 
   assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")
 
   if (query.isEmpty) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 6feb4587b816f..cf494fcd87451 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -376,6 +376,10 @@ class CollationSQLExpressionsSuite
           StructField("B", DoubleType, nullable = true)
         )),
       CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "",
+        Row("Spark"), Seq(
+          StructField("a", StringType, nullable = true)
+        )),
+      CsvToStructsTestCase("\"Spark\"", "UTF8_BINARY", "'a STRING COLLATE UNICODE'", "",
         Row("Spark"), Seq(
           StructField("a", StringType("UNICODE"), nullable = true)
         )),
@@ -1291,6 +1295,10 @@ class CollationSQLExpressionsSuite
           StructField("B", DoubleType, nullable = true)
         )),
       XmlToStructsTestCase("
Spark
Spark