[SPARK-49992][SQL] Default collation resolution for DDL and DML queries
### What changes were proposed in this pull request?
This PR proposes that DDL commands (create/alter view/table, add/replace columns) no longer use the session-level collation.

Also, resolution of the default collation should happen in the analyzer, not in the parser. However, because we detect the default string type by reference equality with the `StringType` object, we cannot simply replace that object with `StringType("UTF8_BINARY")`: the two compare as equal, so the tree node framework would just return the old plan. We therefore perform the resolution in two passes, first changing the `StringType` object into a `TemporaryStringType` and then changing that into `StringType("UTF8_BINARY")`, which is no longer considered a default string type.
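As a minimal illustration of the equality behavior described above (not part of the change itself):

```scala
val default = StringType                  // the parser's default-string-type marker
val resolved = StringType("UTF8_BINARY")  // the intended resolution

default == resolved   // true: same collation, so a tree transform sees "no change"
default.eq(resolved)  // false: only reference equality tells them apart
```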

In addition, the dependent rules `ResolveInlineTables` and `CollationTypeCoercion` are updated so that they do not execute while the plan still contains unresolved string types.

### Why are the changes needed?
The default collation for DDL commands should be tied to the object being created or altered (e.g., table, view, schema) rather than to the session-level setting. Since object-level collations are not yet supported, we assume the UTF8_BINARY collation by default for now.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added new unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48962 from stefankandic/fixSessionCollationOrder.

Lead-authored-by: Stefan Kandic <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
stefankandic and cloud-fan committed Nov 29, 2024
1 parent aaf8590 commit b45045e
Showing 26 changed files with 855 additions and 62 deletions.
@@ -76,7 +76,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
       case (TIMESTAMP_LTZ, Nil) => TimestampType
       case (STRING, Nil) =>
         typeCtx.children.asScala.toSeq match {
-          case Seq(_) => SqlApiConf.get.defaultStringType
+          case Seq(_) => StringType
           case Seq(_, ctx: CollateClauseContext) =>
             val collationName = visitCollateClause(ctx)
             val collationId = CollationFactory.collationNameToId(collationName)
@@ -17,7 +17,6 @@

 package org.apache.spark.sql.internal.types

-import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

 /**
@@ -26,7 +25,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
     extends AbstractDataType
     with Serializable {
-  override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
+  override private[sql] def defaultConcreteType: DataType = StringType
   override private[sql] def simpleString: String = "string"

   override private[sql] def acceptsType(other: DataType): Boolean = other match {
@@ -110,4 +110,13 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || elementType.existsRecursively(f)
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      f(this)
+    } else {
+      ArrayType(elementType.transformRecursively(f), containsNull)
+    }
+  }
 }
@@ -105,6 +105,13 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)

+  /**
+   * Recursively applies the provided partial function `f` to transform this DataType tree.
+   */
+  private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) f(this) else this
+  }
+
   final override private[sql] def defaultConcreteType: DataType = this

   override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
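For context, a sketch of how the new hook is meant to be used (hypothetical caller; `transformRecursively` is `private[spark]`, so this only compiles inside Spark's own packages, and the name-based `StringType(...)` apply is assumed from the new rule further below):

```scala
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("tags", ArrayType(StringType))))

// Rebuilds the tree: the partial function fires on default StringType leaves,
// while the container overrides (array/map/struct) recurse into their children.
val collated = schema.transformRecursively {
  case st: StringType if st.eq(StringType) => StringType("UTF8_LCASE")
}
```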
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -89,6 +89,18 @@ case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      f(this)
+    } else {
+      MapType(
+        keyType.transformRecursively(f),
+        valueType.transformRecursively(f),
+        valueContainsNull)
+    }
+  }
 }

 /**
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
  * The id of collation for this StringType.
  */
 @Stable
-class StringType private (val collationId: Int) extends AtomicType with Serializable {
+class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable {

   /**
    * Support for Binary Equality implies that strings are considered equal only if they are byte
@@ -75,7 +75,14 @@ class StringType private (val collationId: Int) extends AtomicType with Serializable {
    */
   override def typeName: String =
     if (isUTF8BinaryCollation) "string"
-    else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}"
+    else s"string collate $collationName"
+
+  override def toString: String =
+    if (isUTF8BinaryCollation) "StringType"
+    else s"StringType($collationName)"
+
+  private[sql] def collationName: String =
+    CollationFactory.fetchCollation(collationId).collationName

   // Due to backwards compatibility and compatibility with other readers
   // all string types are serialized in json as regular strings and
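The new `toString` makes the collation visible when plans are printed; illustratively:

```scala
StringType.toString             // "StringType"
StringType("UNICODE").toString  // "StringType(UNICODE)"
```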
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -502,6 +502,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || fields.exists(field => field.dataType.existsRecursively(f))
   }
+
+  override private[spark] def transformRecursively(
+      f: PartialFunction[DataType, DataType]): DataType = {
+    if (f.isDefinedAt(this)) {
+      return f(this)
+    }
+
+    val newFields = fields.map { field =>
+      field.copy(dataType = field.dataType.transformRecursively(f))
+    }
+    StructType(newFields)
+  }
 }

 /**
@@ -320,6 +320,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveAliases ::
       ResolveSubquery ::
       ResolveSubqueryColumnAliases ::
+      ResolveDefaultStringTypes ::
       ResolveWindowOrder ::
       ResolveWindowFrame ::
       ResolveNaturalAndUsingJoin ::
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 import org.apache.spark.sql.util.SchemaUtils

 /**
@@ -93,13 +93,6 @@ object CollationTypeCoercion {
       val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
       stringPadExpr.withNewChildren(Seq(newStr, len, newPad))

-    case raiseError: RaiseError =>
-      val newErrorParams = raiseError.errorParms.dataType match {
-        case MapType(StringType, StringType, _) => raiseError.errorParms
-        case _ => Cast(raiseError.errorParms, MapType(StringType, StringType))
-      }
-      raiseError.withNewChildren(Seq(raiseError.errorClass, newErrorParams))
-
     case framelessOffsetWindow @ (_: Lag | _: Lead) =>
       val Seq(input, offset, default) = framelessOffsetWindow.children
       val Seq(newInput, newDefault) = collateToSingleType(Seq(input, default))
@@ -219,6 +212,11 @@ object CollationTypeCoercion {
    */
   private def findLeastCommonStringType(expressions: Seq[Expression]): Option[StringType] = {
     if (!expressions.exists(e => SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+      // if there are no collated types we don't need to do anything
       return None
+    } else if (ResolveDefaultStringTypes.needsResolution(expressions)) {
+      // if any of the string types are still not resolved
+      // we need to wait for them to be resolved first
+      return None
     }

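Why the new early-out matters: a sketch of the failure mode it avoids (hypothetical session setting, simplified expressions):

```scala
// Suppose the session default collation is UTF8_LCASE and we are coercing:
//   Concat(Literal("a"), AttributeReference("c", StringType("UNICODE")))
//
// Before ResolveDefaultStringTypes runs, Literal("a") still carries the
// placeholder StringType (collation id 0, i.e. UTF8_BINARY). Computing a
// least common string type now would bake in a decision that the later
// rewrite of the literal to StringType("UTF8_LCASE") invalidates, so the
// rule returns None and retries on a later analyzer iteration.
```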
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterViewAs, ColumnDefinition, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V1CreateTablePlan, V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.types.{DataType, StringType}
+
+/**
+ * Resolves default string types in queries and commands. For queries, the default string type is
+ * determined by the session's default string type. For DDL, the default string type is the
+ * default type of the object (table -> schema -> catalog). However, this is not implemented yet.
+ * So, we will just use UTF8_BINARY for now.
+ */
+object ResolveDefaultStringTypes extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val newPlan = apply0(plan)
+    if (plan.ne(newPlan)) {
+      // Due to how tree transformations work and StringType object being equal to
+      // StringType("UTF8_BINARY"), we need to transform the plan twice
+      // to ensure the correct results for occurrences of default string type.
+      val finalPlan = apply0(newPlan)
+      RuleExecutor.forceAdditionalIteration(finalPlan)
+      finalPlan
+    } else {
+      newPlan
+    }
+  }
+
+  private def apply0(plan: LogicalPlan): LogicalPlan = {
+    if (isDDLCommand(plan)) {
+      transformDDL(plan)
+    } else {
+      transformPlan(plan, sessionDefaultStringType)
+    }
+  }
+
+  /**
+   * Returns whether the given `plan` needs to have its
+   * default string type resolved.
+   */
+  def needsResolution(plan: LogicalPlan): Boolean = {
+    if (!isDDLCommand(plan) && isDefaultSessionCollationUsed) {
+      return false
+    }
+
+    plan.exists(node => needsResolution(node.expressions))
+  }
+
+  /**
+   * Returns whether any of the given `expressions` needs to have its
+   * default string type resolved.
+   */
+  def needsResolution(expressions: Seq[Expression]): Boolean = {
+    expressions.exists(needsResolution)
+  }
+
+  /**
+   * Returns whether the given `expression` needs to have its
+   * default string type resolved.
+   */
+  def needsResolution(expression: Expression): Boolean = {
+    expression.exists(e => transformExpression.isDefinedAt(e))
+  }
+
+  private def isDefaultSessionCollationUsed: Boolean = conf.defaultStringType == StringType
+
+  /**
+   * Returns the default string type that should be used in a given DDL command (for now always
+   * UTF8_BINARY).
+   */
+  private def stringTypeForDDLCommand(table: LogicalPlan): StringType =
+    StringType("UTF8_BINARY")
+
+  /** Returns the session default string type */
+  private def sessionDefaultStringType: StringType =
+    StringType(conf.defaultStringType.collationId)
+
+  private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
+    case _: AddColumns | _: ReplaceColumns | _: AlterColumn => true
+    case _ => isCreateOrAlterPlan(plan)
+  }
+
+  private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match {
+    case _: V1CreateTablePlan | _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true
+    case _ => false
+  }
+
+  private def transformDDL(plan: LogicalPlan): LogicalPlan = {
+    val newType = stringTypeForDDLCommand(plan)
+
+    plan resolveOperators {
+      case p if isCreateOrAlterPlan(p) =>
+        transformPlan(p, newType)
+
+      case addCols: AddColumns =>
+        addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, newType))
+
+      case replaceCols: ReplaceColumns =>
+        replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType))
+
+      case alter: AlterColumn
+          if alter.dataType.isDefined && hasDefaultStringType(alter.dataType.get) =>
+        alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, newType)))
+    }
+  }
+
+  /**
+   * Transforms the given plan by transforming all expressions in its operators to use the given
+   * new type instead of the default string type.
+   */
+  private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
+    plan resolveExpressionsUp { expression =>
+      transformExpression
+        .andThen(_.apply(newType))
+        .applyOrElse(expression, identity[Expression])
+    }
+  }
+
+  /**
+   * Transforms the given expression by changing all default string types to the given new type.
+   */
+  private def transformExpression: PartialFunction[Expression, StringType => Expression] = {
+    case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) =>
+      newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
+
+    case cast: Cast if hasDefaultStringType(cast.dataType) =>
+      newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
+
+    case Literal(value, dt) if hasDefaultStringType(dt) =>
+      newType => Literal(value, replaceDefaultStringType(dt, newType))
+  }
+
+  private def hasDefaultStringType(dataType: DataType): Boolean =
+    dataType.existsRecursively(isDefaultStringType)
+
+  private def isDefaultStringType(dataType: DataType): Boolean = {
+    dataType match {
+      case st: StringType =>
+        // should only return true for StringType object and not StringType("UTF8_BINARY")
+        st.eq(StringType) || st.isInstanceOf[TemporaryStringType]
+      case _ => false
+    }
+  }
+
+  private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = {
+    dataType.transformRecursively {
+      case currentType: StringType if isDefaultStringType(currentType) =>
+        if (currentType == newType) {
+          TemporaryStringType()
+        } else {
+          newType
+        }
+    }
+  }
+
+  private def replaceColumnTypes(
+      colTypes: Seq[QualifiedColType],
+      newType: StringType): Seq[QualifiedColType] = {
+    colTypes.map {
+      case colWithDefault if hasDefaultStringType(colWithDefault.dataType) =>
+        val replaced = replaceDefaultStringType(colWithDefault.dataType, newType)
+        colWithDefault.copy(dataType = replaced)
+
+      case col => col
+    }
+  }
+}
+
+case class TemporaryStringType() extends StringType(1) {
+  override def toString: String = s"TemporaryStringType($collationId)"
+}
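A walk-through of the two-pass replacement, using the values from the code above (UTF8_BINARY as the DDL target):

```scala
// Pass 1: the default StringType is value-equal to the target
//   StringType("UTF8_BINARY"), so replacing it directly would be dropped by
//   the tree framework as a no-op; the rule swaps in the unequal placeholder:
//     StringType  ->  TemporaryStringType()      // collationId 1 != 0
//
// Pass 2: the placeholder still counts as a "default" string type per
//   isDefaultStringType, but now differs from the target, so the
//   replacement sticks:
//     TemporaryStringType()  ->  StringType("UTF8_BINARY")
//
// The result is value-equal to StringType but no longer reference-equal,
// so it is not treated as unresolved again.
```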
@@ -29,8 +29,12 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 object ResolveInlineTables extends Rule[LogicalPlan] with EvalHelper {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
-      case table: UnresolvedInlineTable if table.expressionsResolved =>
+      case table: UnresolvedInlineTable if canResolveTable(table) =>
         EvaluateUnresolvedInlineTable.evaluateUnresolvedInlineTable(table)
     }
   }
+
+  private def canResolveTable(table: UnresolvedInlineTable): Boolean = {
+    table.expressionsResolved && !ResolveDefaultStringTypes.needsResolution(table)
+  }
 }
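A hypothetical repro for this guard (the session-collation conf key is assumed here, not taken from this PR):

```scala
// Assumed conf key for the session default collation.
spark.conf.set("spark.sql.session.collation.default", "UTF8_LCASE")

// The literals start out with the placeholder default string type.
// Without the guard, the inline table could be materialized before
// ResolveDefaultStringTypes rewrites the literals to UTF8_LCASE strings.
spark.sql("SELECT * FROM VALUES ('a'), ('A') AS t(c)")
```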
@@ -415,7 +415,7 @@ abstract class TypeCoercionHelper {
         if conf.concatBinaryAsString ||
           !children.map(_.dataType).forall(_ == BinaryType) =>
       val newChildren = c.children.map { e =>
-        implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+        implicitCast(e, StringType).getOrElse(e)
       }
       c.copy(children = newChildren)
     case other => other
@@ -465,7 +465,7 @@ abstract class TypeCoercionHelper {
       if (conf.eltOutputAsString ||
         !children.tail.map(_.dataType).forall(_ == BinaryType)) {
         children.tail.map { e =>
-          implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e)
+          implicitCast(e, StringType).getOrElse(e)
         }
       } else {
         children.tail
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeWithCollation
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: DataType)
   override def foldable: Boolean = false
   override def nullable: Boolean = true
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCollation, MapType(StringType, StringType))
+    Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation))

   override def left: Expression = errorClass
   override def right: Expression = errorParms
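With `AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)` as the expected input type, the ordinary implicit-cast machinery can accept error-parameter maps with any string collation, which appears to be what makes the hard-coded `RaiseError` special case removed from `CollationTypeCoercion` above redundant.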