[SPARK-40370][SQL] Migrate type check fails to error classes in CAST
### What changes were proposed in this pull request?
In this PR, I propose to use error classes for type check failures in the `CAST` expression.

### Why are the changes needed?
Migrating to error classes unifies Spark SQL error messages and improves the searchability of errors.
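
For example, a disallowed cast now carries a stable, greppable error class. The snippet below is a sketch of the user-visible shape: the session setup and exception framing are assumptions, while the error-class name and message text follow the templates this PR adds to `error-classes.json`:

```
// Hypothetical spark-shell session with ANSI mode on. The message body follows
// the CAST_WITH_FUN_SUGGESTION template added in this PR; the exact framing of
// the reported error is an assumption.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT CAST(TIMESTAMP'2022-09-15 00:00:00' AS BIGINT)")
// [DATATYPE_MISMATCH.CAST_WITH_FUN_SUGGESTION] cannot cast "TIMESTAMP" to "BIGINT".
// To convert values from "TIMESTAMP" to "BIGINT", you can use the functions
// `UNIX_SECONDS`/`UNIX_MILLIS`/`UNIX_MICROS` instead.
```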

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running the modified test suites:
```
$ build/sbt "test:testOnly *CastWithAnsiOnSuite"
$ build/sbt "test:testOnly *DatasetSuite"
$ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
```

Closes apache#37869 from MaxGekk/datatype-mismatch-in-cast.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
MaxGekk committed Sep 15, 2022
1 parent 034e48f commit 6d067d0
Showing 8 changed files with 300 additions and 97 deletions.
core/src/main/resources/error/error-classes.json (17 additions, 0 deletions)
@@ -95,6 +95,23 @@
     "message" : [
       "the binary operator requires the input type <inputType>, not <actualDataType>."
     ]
   },
+  "CAST_WITHOUT_SUGGESTION" : {
+    "message" : [
+      "cannot cast <srcType> to <targetType>."
+    ]
+  },
+  "CAST_WITH_CONF_SUGGESTION" : {
+    "message" : [
+      "cannot cast <srcType> to <targetType> with ANSI mode on.",
+      "If you have to cast <srcType> to <targetType>, you can set <config> as <configVal>."
+    ]
+  },
+  "CAST_WITH_FUN_SUGGESTION" : {
+    "message" : [
+      "cannot cast <srcType> to <targetType>.",
+      "To convert values from <srcType> to <targetType>, you can use the functions <functionNames> instead."
+    ]
+  }
 }
},
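
To make the templates concrete, here is a minimal sketch of how placeholder substitution could work; the `render` helper below is illustrative only and is not Spark's actual rendering code:

```
import scala.util.matching.Regex

// Illustrative only: substitute <param> placeholders in an error-class template.
def render(template: Seq[String], params: Map[String, String]): String = {
  val placeholder: Regex = "<([a-zA-Z]+)>".r
  placeholder.replaceAllIn(template.mkString(" "), m =>
    Regex.quoteReplacement(params.getOrElse(m.group(1), m.matched)))
}

render(
  Seq("cannot cast <srcType> to <targetType>."),
  Map("srcType" -> "\"BIGINT\"", "targetType" -> "\"DATE\""))
// returns: cannot cast "BIGINT" to "DATE".
```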
@@ -187,9 +187,10 @@ private[spark] object SparkThrowableHelper {
       val messageParameters = e.getMessageParameters
       if (!messageParameters.isEmpty) {
         g.writeObjectFieldStart("messageParameters")
-        messageParameters.asScala.toSeq.sortBy(_._1).foreach { case (name, value) =>
-          g.writeStringField(name, value)
-        }
+        messageParameters.asScala
+          .toMap // To remove duplicates
+          .toSeq.sortBy(_._1)
+          .foreach { case (name, value) => g.writeStringField(name, value) }
         g.writeEndObject()
       }
       val queryContext = e.getQueryContext
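
The `.toMap` step leans on standard Scala semantics: converting a sequence of pairs to a `Map` keeps one entry per key (the last occurrence wins), which is what collapses duplicated parameters before serialization. A minimal illustration:

```
// Duplicate keys collapse when a sequence of pairs is converted to a Map.
val pairs = Seq("srcType" -> "\"BIGINT\"", "srcType" -> "\"BIGINT\"", "targetType" -> "\"DATE\"")
pairs.toMap.toSeq.sortBy(_._1)
// List((srcType,"BIGINT"), (targetType,"DATE"))
```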
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit._
 import org.apache.spark.SparkArithmeticException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag}
@@ -33,14 +33,14 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
 import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}

-object Cast {
+object Cast extends QueryErrorsBase {
   /**
    * As per section 6.13 "cast specification" in "Information technology — Database languages " +
    * "- SQL — Part 2: Foundation (SQL/Foundation)":
@@ -412,47 +413,48 @@
     }
   }
-
-  // Show suggestion on how to complete the disallowed explicit casting with built-in type
-  // conversion functions.
-  private def suggestionOnConversionFunctions (
-      from: DataType,
-      to: DataType,
-      functionNames: String): String = {
-    // scalastyle:off line.size.limit
-    s"""cannot cast ${from.catalogString} to ${to.catalogString}.
-       |To convert values from ${from.catalogString} to ${to.catalogString}, you can use $functionNames instead.
-       |""".stripMargin
-    // scalastyle:on line.size.limit
-  }

   def typeCheckFailureMessage(
       from: DataType,
       to: DataType,
-      fallbackConf: Option[(String, String)]): String =
+      fallbackConf: Option[(String, String)]): DataTypeMismatch = {
+    def withFunSuggest(names: String*): DataTypeMismatch = {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITH_FUN_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(from),
+          "targetType" -> toSQLType(to),
+          "functionNames" -> names.map(toSQLId).mkString("/")))
+    }
     (from, to) match {
       case (_: NumericType, TimestampType) =>
-        suggestionOnConversionFunctions(from, to,
-          "functions TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS")
+        withFunSuggest("TIMESTAMP_SECONDS", "TIMESTAMP_MILLIS", "TIMESTAMP_MICROS")

       case (TimestampType, _: NumericType) =>
-        suggestionOnConversionFunctions(from, to, "functions UNIX_SECONDS/UNIX_MILLIS/UNIX_MICROS")
+        withFunSuggest("UNIX_SECONDS", "UNIX_MILLIS", "UNIX_MICROS")

       case (_: NumericType, DateType) =>
-        suggestionOnConversionFunctions(from, to, "function DATE_FROM_UNIX_DATE")
+        withFunSuggest("DATE_FROM_UNIX_DATE")

       case (DateType, _: NumericType) =>
-        suggestionOnConversionFunctions(from, to, "function UNIX_DATE")
+        withFunSuggest("UNIX_DATE")

-      // scalastyle:off line.size.limit
       case _ if fallbackConf.isDefined && Cast.canCast(from, to) =>
-        s"""
-           | cannot cast ${from.catalogString} to ${to.catalogString} with ANSI mode on.
-           | If you have to cast ${from.catalogString} to ${to.catalogString}, you can set ${fallbackConf.get._1} as ${fallbackConf.get._2}.
-           |""".stripMargin
-      // scalastyle:on line.size.limit
+        DataTypeMismatch(
+          errorSubClass = "CAST_WITH_CONF_SUGGESTION",
+          messageParameters = Map(
+            "srcType" -> toSQLType(from),
+            "targetType" -> toSQLType(to),
+            "config" -> toSQLConf(fallbackConf.get._1),
+            "configVal" -> toSQLValue(fallbackConf.get._2, StringType)))

-      case _ => s"cannot cast ${from.catalogString} to ${to.catalogString}"
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "CAST_WITHOUT_SUGGESTION",
+          messageParameters = Map(
+            "srcType" -> toSQLType(from),
+            "targetType" -> toSQLType(to)))
     }
+  }

   def apply(
       child: Expression,
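
For reference, a sketch of what the refactored helper now returns on one of the suggestion branches; the expected structure mirrors the test expectations later in this commit rather than output I have executed:

```
// Sketch: typeCheckFailureMessage now yields a structured DataTypeMismatch
// instead of a preformatted string (imports of LongType/DateType assumed).
val mismatch = Cast.typeCheckFailureMessage(LongType, DateType, fallbackConf = None)
// DataTypeMismatch(
//   errorSubClass = "CAST_WITH_FUN_SUGGESTION",
//   messageParameters = Map(
//     "srcType" -> "\"BIGINT\"",
//     "targetType" -> "\"DATE\"",
//     "functionNames" -> "`DATE_FROM_UNIX_DATE`"))
```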
@@ -487,8 +489,12 @@ case class Cast(
     child: Expression,
     dataType: DataType,
     timeZoneId: Option[String] = None,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends UnaryExpression
-  with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
+    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get))
+  extends UnaryExpression
+  with TimeZoneAwareExpression
+  with NullIntolerant
+  with SupportQueryContext
+  with QueryErrorsBase {

   def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
     this(child, dataType, timeZoneId, evalMode = EvalMode.fromSQLConf(SQLConf.get))
@@ -509,7 +515,7 @@ case class Cast(
     evalMode == EvalMode.TRY
   }

-  private def typeCheckFailureMessage: String = evalMode match {
+  private def typeCheckFailureInCast: DataTypeMismatch = evalMode match {
     case EvalMode.ANSI =>
       if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
         Cast.typeCheckFailureMessage(child.dataType, dataType,
@@ -522,7 +528,11 @@ case class Cast(
     case EvalMode.TRY =>
       Cast.typeCheckFailureMessage(child.dataType, dataType, None)
     case _ =>
-      s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(child.dataType),
+          "targetType" -> toSQLType(dataType)))
   }

   override def checkInputDataTypes(): TypeCheckResult = {
@@ -535,7 +545,7 @@
     if (canCast) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
+      typeCheckFailureInCast
     }
   }

@@ -27,7 +27,7 @@ import scala.collection.parallel.immutable.ParVector
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -66,21 +66,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null)
   }

-  protected def verifyCastFailure(c: Cast, optionalExpectedMsg: Option[String] = None): Unit = {
+  protected def verifyCastFailure(c: Cast, expected: DataTypeMismatch): Unit = {
     val typeCheckResult = c.checkInputDataTypes()
     assert(typeCheckResult.isFailure)
-    assert(typeCheckResult.isInstanceOf[TypeCheckFailure])
-    val message = typeCheckResult.asInstanceOf[TypeCheckFailure].message
-
-    if (optionalExpectedMsg.isDefined) {
-      assert(message.contains(optionalExpectedMsg.get))
-    } else {
-      assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined)
-      if (evalMode == EvalMode.ANSI) {
-        assert(message.contains("with ANSI mode on"))
-        assert(message.contains(s"set ${SQLConf.ANSI_ENABLED.key} as false"))
-      }
-    }
+    assert(typeCheckResult.isInstanceOf[DataTypeMismatch])
+    val mismatch = typeCheckResult.asInstanceOf[DataTypeMismatch]
+    assert(mismatch === expected)
   }

   test("null cast") {
@@ -936,13 +927,19 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
test("disallow type conversions between Numeric types and Timestamp without time zone type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(TimestampNTZType)
var errorMsg = "cannot cast bigint to timestamp_ntz"
verifyCastFailure(cast(Literal(0L), TimestampNTZType), Some(errorMsg))
verifyCastFailure(
cast(Literal(0L), TimestampNTZType),
DataTypeMismatch(
"CAST_WITHOUT_SUGGESTION",
Map("srcType" -> "\"BIGINT\"", "targetType" -> "\"TIMESTAMP_NTZ\"")))

val timestampNTZLiteral = Literal.create(LocalDateTime.now(), TimestampNTZType)
errorMsg = "cannot cast timestamp_ntz to"
numericTypes.foreach { numericType =>
verifyCastFailure(cast(timestampNTZLiteral, numericType), Some(errorMsg))
verifyCastFailure(
cast(timestampNTZLiteral, numericType),
DataTypeMismatch(
"CAST_WITHOUT_SUGGESTION",
Map("srcType" -> "\"TIMESTAMP_NTZ\"", "targetType" -> s""""${numericType.sql}"""")))
}
}

@@ -23,6 +23,7 @@ import java.time.DateTimeException
 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -141,12 +142,26 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
test("ANSI mode: disallow type conversions between Numeric types and Date type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(DateType)
var errorMsg = "you can use function DATE_FROM_UNIX_DATE instead"
verifyCastFailure(cast(Literal(0L), DateType), Some(errorMsg))
verifyCastFailure(
cast(Literal(0L), DateType),
DataTypeMismatch(
"CAST_WITH_FUN_SUGGESTION",
Map(
"srcType" -> "\"BIGINT\"",
"targetType" -> "\"DATE\"",
"functionNames" -> "`DATE_FROM_UNIX_DATE`")))
val dateLiteral = Literal(1, DateType)
errorMsg = "you can use function UNIX_DATE instead"
numericTypes.foreach { numericType =>
verifyCastFailure(cast(dateLiteral, numericType), Some(errorMsg))
withClue(s"numericType = ${numericType.sql}") {
verifyCastFailure(
cast(dateLiteral, numericType),
DataTypeMismatch(
"CAST_WITH_FUN_SUGGESTION",
Map(
"srcType" -> "\"DATE\"",
"targetType" -> s""""${numericType.sql}"""",
"functionNames" -> "`UNIX_DATE`")))
}
}
}
