Skip to content

Commit

Permalink
[SPARK-49962][SQL] Simplify AbstractStringTypes class hierarchy
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Simplifying the AbstractStringType hierarchy.

### Why are the changes needed?
The addition of trim-sensitive collation (#48336) highlighted the complexity of extending the existing AbstractStringType structure. Besides adding a new parameter to all types inheriting from AbstractStringType, it caused changing the logic of every subclass as well as changing the name of a derived class StringTypeAnyCollation into StringTypeWithCaseAccentSensitivity which could again be subject to change if we keep adding new specifiers.

Looking ahead, the introduction of support for indeterminate collation would further complicate these types. To address this, the proposed changes simplify the design by consolidating common logic into a single base class. This base class will handle core functionality such as trim or indeterminate collation, while a derived class, StringTypeWithCollation (previously awkwardly called StringTypeWithCaseAccentSensitivity), will manage collation specifiers.

This approach allows for easier future extensions: fundamental checks can be handled in the base class, while any new specifiers can be added as optional fields in StringTypeWithCollation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
With existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48459 from stefankandic/refactorStringTypes.

Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
stefankandic authored and MaxGekk committed Oct 17, 2024
1 parent 6362e0c commit 91becf1
Show file tree
Hide file tree
Showing 26 changed files with 226 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1157,15 +1157,13 @@ public static int collationNameToId(String collationName) throws SparkException
return Collation.CollationSpec.collationNameToId(collationName);
}

/**
* Returns whether the ICU collation is not Case Sensitive Accent Insensitive
* for the given collation id.
* This method is used in expressions which do not support CS_AI collations.
*/
public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
public static boolean isCaseInsensitive(int collationId) {
return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
Collation.CollationSpecICU.CaseSensitivity.CS &&
Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
Collation.CollationSpecICU.CaseSensitivity.CI;
}

public static boolean isAccentInsensitive(int collationId) {
return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
Collation.CollationSpecICU.AccentSensitivity.AI;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,34 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* AbstractStringType is an abstract class for StringType with collation support. As every type of
* collation can support trim specifier this class is parametrized with it.
* AbstractStringType is an abstract class for StringType with collation support.
*/
abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
private[sql] def canUseTrimCollation(other: DataType): Boolean =
supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation

override private[sql] def acceptsType(other: DataType): Boolean = other match {
case st: StringType =>
canUseTrimCollation(st) && acceptsStringType(st)
case _ =>
false
}

private[sql] def canUseTrimCollation(other: StringType): Boolean =
supportsTrimCollation || !other.usesTrimCollation

def acceptsStringType(other: StringType): Boolean
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
* Used for expressions supporting only binary collation.
*/
case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
case class StringTypeBinary(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.supportsBinaryEquality
}

object StringTypeBinary extends StringTypeBinary(false) {
Expand All @@ -49,13 +58,13 @@ object StringTypeBinary extends StringTypeBinary(false) {
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
* Used for expressions supporting only binary and lowercase collation.
*/
case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
case class StringTypeBinaryLcase(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.supportsBinaryEquality || other.isUTF8LcaseCollation
}

object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
Expand All @@ -65,31 +74,44 @@ object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
}

/**
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
* Used for expressions supporting collation types with optional case, accent, and trim
* sensitivity specifiers.
*
* Case and accent sensitivity specifiers are supported by default.
*/
case class StringTypeWithCaseAccentSensitivity(
override val supportsTrimCollation: Boolean = false)
case class StringTypeWithCollation(
supportsTrimCollation: Boolean,
supportsCaseSpecifier: Boolean,
supportsAccentSpecifier: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean = {
(supportsCaseSpecifier || !other.isCaseInsensitive) &&
(supportsAccentSpecifier || !other.isAccentInsensitive)
}
}

object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
object StringTypeWithCollation extends StringTypeWithCollation(false, true, true) {
def apply(
supportsTrimCollation: Boolean = false,
supportsCaseSpecifier: Boolean = true,
supportsAccentSpecifier: Boolean = true): StringTypeWithCollation = {
new StringTypeWithCollation(
supportsTrimCollation,
supportsCaseSpecifier,
supportsAccentSpecifier)
}
}

/**
* Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
* CS_AI collation types.
* Used for expressions supporting all possible collation types except those that are
* case-sensitive but accent insensitive (CS_AI).
*/
case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
case class StringTypeNonCSAICollation(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.isCaseInsensitive || !other.isAccentInsensitive
}

object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
private[sql] def supportsLowercaseEquality: Boolean =
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality

private[sql] def isNonCSAI: Boolean =
!CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
private[sql] def isCaseInsensitive: Boolean =
CollationFactory.isCaseInsensitive(collationId)

private[sql] def isAccentInsensitive: Boolean =
CollationFactory.isAccentInsensitive(collationId)

private[sql] def usesTrimCollation: Boolean =
CollationFactory.fetchCollation(collationId).supportsSpaceTrimming
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType,
StringTypeWithCaseAccentSensitivity}
StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence

Expand Down Expand Up @@ -439,7 +439,7 @@ abstract class TypeCoercionBase {
}

case aj @ ArrayJoin(arr, d, nr)
if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) &&
if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
implicitCast(arr, ArrayType(StringType, containsNull)) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -84,7 +84,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("class"),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputType" -> toSQLType(StringTypeWithCollation),
"inputExpr" -> toSQLExpr(children.head)
)
)
Expand All @@ -97,7 +97,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("method"),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputType" -> toSQLType(StringTypeWithCollation),
"inputExpr" -> toSQLExpr(children(1))
)
)
Expand All @@ -115,7 +115,7 @@ case class CallMethodViaReflection(
"requiredType" -> toSQLType(
TypeCollection(BooleanType, ByteType, ShortType,
IntegerType, LongType, FloatType, DoubleType,
StringTypeWithCaseAccentSensitivity)),
StringTypeWithCollation)),
"inputSql" -> toSQLExpr(e),
"inputType" -> toSQLType(e.dataType))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -61,7 +61,7 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {

def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)
.acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -109,7 +109,7 @@ case class HllSketchAgg(
TypeCollection(
IntegerType,
LongType,
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
StringTypeWithCollation(supportsTrimCollation = true),
BinaryType),
IntegerType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._

// scalastyle:off line.contains.tab
Expand Down Expand Up @@ -78,7 +78,7 @@ case class Collate(child: Expression, collationName: String)
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -117,5 +117,5 @@ case class Collation(child: Expression)
Literal.create(collationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
Expand Down Expand Up @@ -1349,7 +1349,7 @@ case class Reverse(child: Expression)

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType))
Seq(TypeCollection(StringTypeWithCollation, ArrayType))

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -2135,12 +2135,12 @@ case class ArrayJoin(
this(array, delimiter, Some(nullReplacement))

override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity,
StringTypeWithCaseAccentSensitivity)
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation,
StringTypeWithCollation)
} else {
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity)
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation)
}

override def children: Seq[Expression] = if (nullReplacement.isDefined) {
Expand Down Expand Up @@ -2861,7 +2861,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
with QueryErrorsBase {

private def allowedTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType)
Seq(StringTypeWithCollation, BinaryType, ArrayType)

final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -147,7 +147,7 @@ case class CsvToStructs(
converter(parser.parse(csv))
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil

override def prettyName: String = "from_csv"

Expand Down
Loading

0 comments on commit 91becf1

Please sign in to comment.