From 14de06e19d38b038e17dd94d43e2cdc767d588e5 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 11 Sep 2024 12:32:20 -0400
Subject: [PATCH] [SPARK-49574][CONNECT][SQL] ExpressionEncoder tracks the
 AgnosticEncoder that created it

### What changes were proposed in this pull request?
This PR makes ExpressionEncoder track the AgnosticEncoder it is created from. The main reason for this change is to allow for situations where both AgnosticEncoders and ExpressionEncoders are used together.

### Why are the changes needed?
This is the first step in creating a shared Encoders object for Classic and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48062 from hvanhovell/SPARK-49574.

Authored-by: Herman van Hovell
Signed-off-by: Herman van Hovell
---
 .../catalyst/encoders/AgnosticEncoder.scala   | 26 ++++---
 .../spark/sql/catalyst/encoders/codecs.scala  |  2 +-
 .../spark/sql/errors/ExecutionErrors.scala    |  4 +
 .../scala/org/apache/spark/sql/Encoders.scala | 25 ++-----
 .../catalyst/DeserializerBuildHelper.scala    | 10 ++-
 .../sql/catalyst/SerializerBuildHelper.scala  |  8 +-
 .../catalyst/encoders/ExpressionEncoder.scala | 75 +++----------------
 .../encoders/KryoSerializationCodecImpl.scala |  2 +-
 .../sql/errors/QueryExecutionErrors.scala     |  4 -
 .../encoders/ExpressionEncoderSuite.scala     |  2 +-
 .../org/apache/spark/sql/DatasetSuite.scala   | 12 ++-
 .../errors/QueryExecutionErrorsSuite.scala    |  2 +-
 12 files changed, 67 insertions(+), 105 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index f1f2ea34323b4..a578495755492 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.encoders
 import java.{sql => jsql}
 import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
 import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
-import java.util.concurrent.ConcurrentHashMap
 
 import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
 import org.apache.spark.util.SparkClassUtils
@@ -115,16 +115,24 @@ object AgnosticEncoders {
     extends StructEncoder[K]
 
   object ProductEncoder {
-    val cachedCls = new ConcurrentHashMap[Int, Class[_]]
-    private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
+    private val MAX_TUPLE_ELEMENTS = 22
+
+    private val tupleClassTags = Array.tabulate[ClassTag[Any]](MAX_TUPLE_ELEMENTS + 1) {
+      case 0 => null
+      case i => ClassTag(SparkClassUtils.classForName(s"scala.Tuple$i"))
+    }
+
+    private[sql] def tuple(
+        encoders: Seq[AgnosticEncoder[_]],
+        elementsCanBeNull: Boolean = false): AgnosticEncoder[_] = {
+      val numElements = encoders.size
+      if (numElements < 1 || numElements > MAX_TUPLE_ELEMENTS) {
+        throw ExecutionErrors.elementsOfTupleExceedLimitError()
+      }
       val fields = encoders.zipWithIndex.map { case (e, id) =>
-        EncoderField(s"_${id + 1}", e, e.nullable, Metadata.empty)
+        EncoderField(s"_${id + 1}", e, e.nullable || elementsCanBeNull, Metadata.empty)
       }
-      val cls = cachedCls.computeIfAbsent(
-        encoders.size,
-        _ =>
-          SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
-      ProductEncoder[Any](ClassTag(cls), fields, None)
+      ProductEncoder[Any](tupleClassTags(numElements), fields, None)
     }
 
     private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
index ceb615b13f99a..0f21972552339 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
@@ -29,7 +29,7 @@ import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
  * @tparam O
  *   output type (typically the internal representation of the data.
  */
-trait Codec[I, O] {
+trait Codec[I, O] extends Serializable {
   def encode(in: I): O
   def decode(out: O): I
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 5b761e9170572..4890ff4431fe6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -216,6 +216,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
   def cannotUseKryoSerialization(): SparkRuntimeException = {
     new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO", messageParameters = Map.empty)
   }
+
+  def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
+  }
 }
 
 private[sql] object ExecutionErrors extends ExecutionErrors
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 9b95f74db3a49..7e040f6232fbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -22,10 +22,8 @@ import java.lang.reflect.Modifier
 import scala.reflect.{classTag, ClassTag}
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
-import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Codec, ExpressionEncoder, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 
@@ -193,7 +191,7 @@ object Encoders {
    *
    * @since 1.6.0
    */
-  def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
+  def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
 
   /**
    * Creates an encoder that serializes objects of type T using Kryo.
@@ -215,7 +213,7 @@ object Encoders {
    *
    * @since 1.6.0
    */
-  def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
+  def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(JavaSerializationCodec)
 
   /**
    * Creates an encoder that serializes objects of type T using generic Java serialization.
@@ -237,24 +235,15 @@ object Encoders {
   }
 
   /** A way to construct encoders using generic serializers. */
-  private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
+  private def genericSerializer[T: ClassTag](
+      provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
     if (classTag[T].runtimeClass.isPrimitive) {
       throw QueryExecutionErrors.primitiveTypesNotSupportedError()
     }
 
     validatePublicClass[T]()
 
-    ExpressionEncoder[T](
-      objSerializer =
-        EncodeUsingSerializer(
-          BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
-      objDeserializer =
-        DecodeUsingSerializer[T](
-          Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
-          classTag[T],
-          kryo = useKryo),
-      clsTag = classTag[T]
-    )
+    ExpressionEncoder(TransformingEncoder(classTag[T], BinaryEncoder, provider))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 40b49506b58aa..4752434015375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._
 
@@ -411,6 +411,12 @@ object DeserializerBuildHelper {
       val result = InitializeJavaBean(newInstance, setters.toMap)
       exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
 
+    case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+      DecodeUsingSerializer(path, tag, kryo = false)
+
+    case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+      DecodeUsingSerializer(path, tag, kryo = true)
+
     case TransformingEncoder(tag, encoder, provider) =>
       Invoke(
         Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 38bf0651d6f1c..daebe15c298f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -398,6 +398,12 @@ object SerializerBuildHelper {
       }
       createSerializerForObject(input, serializedFields)
 
+    case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+      EncodeUsingSerializer(input, kryo = false)
+
+    case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+      EncodeUsingSerializer(input, kryo = true)
+
     case TransformingEncoder(_, encoder, codecProvider) =>
       val encoded = Invoke(
         Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0b5ce65fed6df..8e39ae0389c2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, Java
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts}
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -54,9 +54,9 @@ object ExpressionEncoder {
 
   def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
     new ExpressionEncoder[T](
+      enc,
       SerializerBuildHelper.createSerializer(enc),
-      DeserializerBuildHelper.createDeserializer(enc),
-      enc.clsTag)
+      DeserializerBuildHelper.createDeserializer(enc))
   }
 
   def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false)
@@ -82,63 +82,10 @@ object ExpressionEncoder {
   def tuple(
       encoders: Seq[ExpressionEncoder[_]],
       useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
-    if (encoders.length > 22) {
-      throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
-    }
-
-    encoders.foreach(_.assertUnresolved())
-
-    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
-    val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
-    val serializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
-      assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
-        s"there are ${boundRefs.size}")
-
-      val originalInputObject = boundRefs.head
-      val newInputObject = Invoke(
-        newSerializerInput,
-        s"_${index + 1}",
-        originalInputObject.dataType,
-        returnNullable = originalInputObject.nullable)
-
-      val newSerializer = enc.objSerializer.transformUp {
-        case BoundReference(0, _, _) => newInputObject
-      }
-
-      Alias(newSerializer, s"_${index + 1}")()
-    }
-    val newSerializer = CreateStruct(serializers)
-
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
-    }
-
-    val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
-      assert(getColExprs.size == 1, "object deserializer should have only one " +
-        s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
-
-      val input = GetStructField(newDeserializerInput, index)
-      val childDeserializer = enc.objDeserializer.transformUp {
-        case GetColumnByOrdinal(0, _) => input
-      }
-
-      if (useNullSafeDeserializer && enc.objSerializer.nullable) {
-        nullSafe(input, childDeserializer)
-      } else {
-        childDeserializer
-      }
-    }
-    val newDeserializer =
-      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
-
-    new ExpressionEncoder[Any](
-      nullSafe(newSerializerInput, newSerializer),
-      nullSafe(newDeserializerInput, newDeserializer),
-      ClassTag(cls))
+    val tupleEncoder = AgnosticEncoders.ProductEncoder.tuple(
+      encoders.map(_.encoder),
+      useNullSafeDeserializer)
+    ExpressionEncoder(tupleEncoder)
   }
 
   // Tuple1
@@ -228,6 +175,7 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
  * and a `deserializer`.
  *
+ * @param encoder the `AgnosticEncoder` for type `T`.
 * @param objSerializer An expression that can be used to encode a raw object to corresponding
 *                      Spark SQL representation that can be a primitive column, array, map or a
 *                      struct. This represents how Spark SQL generally serializes an object of
@@ -236,14 +184,15 @@ object ExpressionEncoder {
 *                        representation. This represents how Spark SQL generally deserializes
 *                        a serialized value in Spark SQL representation back to an object of
 *                        type `T`.
- * @param clsTag A classtag for `T`.
 */
 case class ExpressionEncoder[T](
+    encoder: AgnosticEncoder[T],
     objSerializer: Expression,
-    objDeserializer: Expression,
-    clsTag: ClassTag[T])
+    objDeserializer: Expression)
   extends Encoder[T] {
 
+  override def clsTag: ClassTag[T] = encoder.clsTag
+
   /**
    * A sequence of expressions, one for each top-level field that can be used to
    * extract the values from a raw object into an [[InternalRow]]:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
index 5e46e7245c05e..49c7b41f77472 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.SerializerSupport
 
 /**
  * A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array.
  */
-class KryoSerializationCodecImpl extends Codec [Any, Array[Byte]] {
+class KryoSerializationCodecImpl extends Codec[Any, Array[Byte]] {
   private val serializer = SerializerSupport.newSerializer(useKryo = true)
   override def encode(in: Any): Array[Byte] = serializer.serialize(in).array()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index ad8437ed7a50d..0b37cf951a29b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1307,10 +1307,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       messageParameters = Map("badRecord" -> badRecord))
   }
 
-  def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
-  }
-
   def expressionDecodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "EXPRESSION_DECODING_FAILED",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 879b4ef6d3745..0c0c7f12f1764 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -531,7 +531,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     val encoder = ExpressionEncoder(schema, lenient = true)
     val unexpectedSerializer = NaNvl(encoder.objSerializer, encoder.objSerializer)
     val exception = intercept[org.apache.spark.SparkRuntimeException] {
-      new ExpressionEncoder[Row](unexpectedSerializer, encoder.objDeserializer, encoder.clsTag)
+      new ExpressionEncoder[Row](encoder.encoder, unexpectedSerializer, encoder.objDeserializer)
     }
     checkError(
       exception = exception,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 84df437305966..089ce79201dd8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1187,11 +1187,15 @@ class DatasetSuite extends QueryTest
       exception = intercept[AnalysisException] {
         df.as[KryoData]
       },
-      condition = "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+      condition = "CANNOT_UP_CAST_DATATYPE",
       parameters = Map(
-        "sqlExpr" -> "\"a\"",
-        "srcType" -> "\"DOUBLE\"",
-        "targetType" -> "\"BINARY\""))
+        "expression" -> "a",
+        "sourceType" -> "\"DOUBLE\"",
+        "targetType" -> "\"BINARY\"",
+        "details" -> ("The type path of the target object is:\n- root class: " +
+          "\"org.apache.spark.sql.KryoData\"\n" +
+          "You can either add an explicit cast to the input data or choose a " +
+          "higher precision type of the field in the target object")))
   }
 
   test("Java encoder") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index da1366350d03e..00dfd3451d577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -1176,7 +1176,7 @@ class QueryExecutionErrorsSuite
     val enc: ExpressionEncoder[Row] = ExpressionEncoder(rowEnc)
     val deserializer = AttributeReference.apply("v", IntegerType)()
     implicit val im: ExpressionEncoder[Row] = new ExpressionEncoder[Row](
-      enc.objSerializer, deserializer, enc.clsTag)
+      rowEnc, enc.objSerializer, deserializer)
     checkError(
       exception = intercept[SparkRuntimeException] {
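
Editor's note (not part of the patch): the sketch below models the central change in a
self-contained way. After this PR, an ExpressionEncoder is always derived from an
AgnosticEncoder and keeps a reference to it, so clsTag becomes a derived member instead of a
third constructor argument. Type names mirror the patch, but the bodies are illustrative
stand-ins; in particular, Expression here is a placeholder for Catalyst's expression tree.

import scala.reflect.ClassTag

// Placeholder for org.apache.spark.sql.catalyst.expressions.Expression.
final case class Expression(description: String)

// Backend-independent description of how to encode a T (stand-in for AgnosticEncoder).
trait AgnosticEncoder[T] {
  def clsTag: ClassTag[T]
}

final case class DemoEncoder[T](clsTag: ClassTag[T]) extends AgnosticEncoder[T]

// As in the patch: the AgnosticEncoder comes first, and clsTag is derived from it.
final case class ExpressionEncoder[T](
    encoder: AgnosticEncoder[T],
    objSerializer: Expression,
    objDeserializer: Expression) {
  def clsTag: ClassTag[T] = encoder.clsTag
}

object ExpressionEncoder {
  // Mirrors the new apply(enc: AgnosticEncoder[T]): both Catalyst expressions are
  // derived from the agnostic description, and the encoder itself is retained.
  def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] =
    ExpressionEncoder(
      enc,
      Expression(s"serializer for ${enc.clsTag.runtimeClass.getName}"),
      Expression(s"deserializer for ${enc.clsTag.runtimeClass.getName}"))
}

object Demo extends App {
  val agnostic = DemoEncoder(implicitly[ClassTag[(Int, String)]])
  val expr = ExpressionEncoder(agnostic)
  // The originating AgnosticEncoder is now recoverable from the ExpressionEncoder,
  // which is what lets Classic and Connect share encoder-producing code paths.
  assert(expr.encoder eq agnostic)
  println(expr.clsTag) // prints: scala.Tuple2
}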
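
A second sketch, for the reworked ProductEncoder.tuple: the old per-arity ConcurrentHashMap
cache is replaced by an eagerly built table of tuple ClassTags plus an up-front arity check.
This standalone version substitutes plain Class.forName for SparkClassUtils.classForName and a
generic exception for the Spark error class, both of which are simplifications for the demo.

import scala.reflect.ClassTag

object TupleClassTags {
  private val MAX_TUPLE_ELEMENTS = 22

  // Index 0 is unused; index i holds the ClassTag for scala.TupleI. Loading all 22 tuple
  // classes once removes the need for a concurrent cache keyed by arity.
  private val tupleClassTags = Array.tabulate[ClassTag[Any]](MAX_TUPLE_ELEMENTS + 1) {
    case 0 => null
    case i => ClassTag(Class.forName(s"scala.Tuple$i"))
  }

  // Mirrors the arity check the patch performs before building a tuple encoder.
  def tagFor(numElements: Int): ClassTag[Any] = {
    if (numElements < 1 || numElements > MAX_TUPLE_ELEMENTS) {
      throw new UnsupportedOperationException(s"unsupported tuple arity: $numElements")
    }
    tupleClassTags(numElements)
  }
}

object TupleTagsDemo extends App {
  println(TupleClassTags.tagFor(3)) // prints: scala.Tuple3
}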