[SPARK-49574][CONNECT][SQL] ExpressionEncoder tracks the AgnosticEncoder that created it

### What changes were proposed in this pull request?
This PR makes ExpressionEncoder track the AgnosticEncoder it was created from. The main reason for this change is to support situations where AgnosticEncoders and ExpressionEncoders are used together.
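
A minimal sketch of the resulting relationship (internal Catalyst API; the choice of `PrimitiveIntEncoder` is illustrative, the names come from the diff below):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder

// Build an ExpressionEncoder from an AgnosticEncoder; after this change the
// ExpressionEncoder keeps a reference to the AgnosticEncoder it came from.
val agnostic = PrimitiveIntEncoder
val expression = ExpressionEncoder(agnostic)

// The originating AgnosticEncoder is recoverable, and clsTag is derived from it.
assert(expression.encoder eq agnostic)
assert(expression.clsTag == agnostic.clsTag)
```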

### Why are the changes needed?
This is the first step in creating a shared Encoders object for Classic and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48062 from hvanhovell/SPARK-49574.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
hvanhovell committed Sep 11, 2024
1 parent e63b560 commit 14de06e
Showing 12 changed files with 67 additions and 105 deletions.
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.encoders
import java.{sql => jsql}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util.concurrent.ConcurrentHashMap

import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkClassUtils
@@ -115,16 +115,24 @@ object AgnosticEncoders {
extends StructEncoder[K]

object ProductEncoder {
val cachedCls = new ConcurrentHashMap[Int, Class[_]]
private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
private val MAX_TUPLE_ELEMENTS = 22

private val tupleClassTags = Array.tabulate[ClassTag[Any]](MAX_TUPLE_ELEMENTS + 1) {
case 0 => null
case i => ClassTag(SparkClassUtils.classForName(s"scala.Tuple$i"))
}

private[sql] def tuple(
encoders: Seq[AgnosticEncoder[_]],
elementsCanBeNull: Boolean = false): AgnosticEncoder[_] = {
val numElements = encoders.size
if (numElements < 1 || numElements > MAX_TUPLE_ELEMENTS) {
throw ExecutionErrors.elementsOfTupleExceedLimitError()
}
val fields = encoders.zipWithIndex.map { case (e, id) =>
EncoderField(s"_${id + 1}", e, e.nullable, Metadata.empty)
EncoderField(s"_${id + 1}", e, e.nullable || elementsCanBeNull, Metadata.empty)
}
val cls = cachedCls.computeIfAbsent(
encoders.size,
_ =>
SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
ProductEncoder[Any](ClassTag(cls), fields, None)
ProductEncoder[Any](tupleClassTags(numElements), fields, None)
}

private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
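For context, a hedged sketch of the reworked helper as called from inside the `org.apache.spark.sql` package (the method is `private[sql]`; the component encoders are illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder, ProductEncoder}

// Encoder for (Int, Long); fields keep their _1 / _2 names.
val pairEncoder = ProductEncoder.tuple(Seq(PrimitiveIntEncoder, PrimitiveLongEncoder))

// With elementsCanBeNull = true every tuple field is marked nullable, which the
// ExpressionEncoder.tuple path below uses to build null-safe deserializers.
val lenientPair = ProductEncoder.tuple(
  Seq(PrimitiveIntEncoder, PrimitiveLongEncoder),
  elementsCanBeNull = true)
```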
@@ -29,7 +29,7 @@ import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
* @tparam O
* output type (typically the internal representation of the data).
*/
trait Codec[I, O] {
trait Codec[I, O] extends Serializable {
def encode(in: I): O
def decode(out: O): I
}
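The `Serializable` bound matters because codec instances end up embedded in generated encoder expressions that may travel with a plan. A minimal, hypothetical codec implementing the contract:

```scala
import org.apache.spark.sql.catalyst.encoders.Codec

// Hypothetical codec: stores a String as UTF-8 bytes. Because Codec now extends
// Serializable, an instance like this can be captured by encoder expression trees
// (for example inside a TransformingEncoder) and shipped with them.
class Utf8Codec extends Codec[String, Array[Byte]] {
  override def encode(in: String): Array[Byte] = in.getBytes("UTF-8")
  override def decode(out: Array[Byte]): String = new String(out, "UTF-8")
}
```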
@@ -216,6 +216,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def cannotUseKryoSerialization(): SparkRuntimeException = {
new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO", messageParameters = Map.empty)
}

def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
}
}

private[sql] object ExecutionErrors extends ExecutionErrors
25 changes: 7 additions & 18 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -22,10 +22,8 @@ import java.lang.reflect.Modifier
import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Codec, ExpressionEncoder, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._

@@ -193,7 +191,7 @@ object Encoders {
*
* @since 1.6.0
*/
def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)

/**
* Creates an encoder that serializes objects of type T using Kryo.
@@ -215,7 +213,7 @@
*
* @since 1.6.0
*/
def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(JavaSerializationCodec)

/**
* Creates an encoder that serializes objects of type T using generic Java serialization.
@@ -237,24 +235,15 @@
}

/** A way to construct encoders using generic serializers. */
private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
private def genericSerializer[T: ClassTag](
provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
if (classTag[T].runtimeClass.isPrimitive) {
throw QueryExecutionErrors.primitiveTypesNotSupportedError()
}

validatePublicClass[T]()

ExpressionEncoder[T](
objSerializer =
EncodeUsingSerializer(
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
objDeserializer =
DecodeUsingSerializer[T](
Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
classTag[T],
kryo = useKryo),
clsTag = classTag[T]
)
ExpressionEncoder(TransformingEncoder(classTag[T], BinaryEncoder, provider))
}

/**
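User-visible behaviour of the two factories should be unchanged; only the internal plumbing moves to `TransformingEncoder`. A usage sketch, assuming `Payload` is a top-level (public) class:

```scala
import org.apache.spark.sql.{Encoder, Encoders}

// Both factories still yield an Encoder[T] backed by binary (de)serialization.
// Internally they now expand to
//   ExpressionEncoder(TransformingEncoder(classTag[T], BinaryEncoder, codecProvider))
// rather than hand-built EncodeUsingSerializer / DecodeUsingSerializer expressions.
case class Payload(id: Long, tags: Map[String, String])

val kryoEnc: Encoder[Payload] = Encoders.kryo[Payload]
val javaEnc: Encoder[Payload] = Encoders.javaSerialization[Payload]
```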
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._

@@ -411,6 +411,12 @@ object DeserializerBuildHelper {
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)

case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = false)

case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = true)

case TransformingEncoder(tag, encoder, provider) =>
Invoke(
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
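A hedged sketch of the encoders these new cases match on (the element type `Seq[String]` is illustrative):

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.{JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}

// When the codec provider is one of the two well-known singletons, the deserializer
// is lowered straight to DecodeUsingSerializer (kryo = false / true). Any other
// provider falls through to the generic Invoke-on-Codec#decode branch below.
val javaBacked = TransformingEncoder(classTag[Seq[String]], BinaryEncoder, JavaSerializationCodec)
val kryoBacked = TransformingEncoder(classTag[Seq[String]], BinaryEncoder, KryoSerializationCodec)
```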
@@ -21,7 +21,7 @@ import scala.language.existentials

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -398,6 +398,12 @@ object SerializerBuildHelper {
}
createSerializerForObject(input, serializedFields)

case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
EncodeUsingSerializer(input, kryo = false)

case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
EncodeUsingSerializer(input, kryo = true)

case TransformingEncoder(_, encoder, codecProvider) =>
val encoded = Invoke(
Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
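The serializer side mirrors the deserializer: the two well-known providers become `EncodeUsingSerializer`, and any other provider keeps the generic `Codec#encode` path. A sketch with a hypothetical provider (`UpperCaseCodecProvider` is not part of Spark):

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

// Hypothetical provider that is neither JavaSerializationCodec nor KryoSerializationCodec,
// so the serializer takes the generic branch: the value flows through Codec#encode via an
// Invoke on a Literal that wraps the codec instance (which is serializable, see the Codec change).
object UpperCaseCodecProvider extends (() => Codec[Any, String]) with Serializable {
  override def apply(): Codec[Any, String] = new Codec[Any, String] {
    override def encode(in: Any): String = in.toString.toUpperCase
    override def decode(out: String): Any = out.toLowerCase
  }
}

val shoutingStrings = TransformingEncoder(classTag[String], StringEncoder, UpperCaseCodecProvider)
```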
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, Java
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts}
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -54,9 +54,9 @@ object ExpressionEncoder {

def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
new ExpressionEncoder[T](
enc,
SerializerBuildHelper.createSerializer(enc),
DeserializerBuildHelper.createDeserializer(enc),
enc.clsTag)
DeserializerBuildHelper.createDeserializer(enc))
}

def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false)
@@ -82,63 +82,10 @@
def tuple(
encoders: Seq[ExpressionEncoder[_]],
useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
if (encoders.length > 22) {
throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
}

encoders.foreach(_.assertUnresolved())

val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
val serializers = encoders.zipWithIndex.map { case (enc, index) =>
val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
s"there are ${boundRefs.size}")

val originalInputObject = boundRefs.head
val newInputObject = Invoke(
newSerializerInput,
s"_${index + 1}",
originalInputObject.dataType,
returnNullable = originalInputObject.nullable)

val newSerializer = enc.objSerializer.transformUp {
case BoundReference(0, _, _) => newInputObject
}

Alias(newSerializer, s"_${index + 1}")()
}
val newSerializer = CreateStruct(serializers)

def nullSafe(input: Expression, result: Expression): Expression = {
If(IsNull(input), Literal.create(null, result.dataType), result)
}

val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
assert(getColExprs.size == 1, "object deserializer should have only one " +
s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")

val input = GetStructField(newDeserializerInput, index)
val childDeserializer = enc.objDeserializer.transformUp {
case GetColumnByOrdinal(0, _) => input
}

if (useNullSafeDeserializer && enc.objSerializer.nullable) {
nullSafe(input, childDeserializer)
} else {
childDeserializer
}
}
val newDeserializer =
NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)

new ExpressionEncoder[Any](
nullSafe(newSerializerInput, newSerializer),
nullSafe(newDeserializerInput, newDeserializer),
ClassTag(cls))
val tupleEncoder = AgnosticEncoders.ProductEncoder.tuple(
encoders.map(_.encoder),
useNullSafeDeserializer)
ExpressionEncoder(tupleEncoder)
}

// Tuple1
Expand Down Expand Up @@ -228,6 +175,7 @@ object ExpressionEncoder {
* A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
* and a `deserializer`.
*
* @param encoder the `AgnosticEncoder` for type `T`.
* @param objSerializer An expression that can be used to encode a raw object to corresponding
* Spark SQL representation that can be a primitive column, array, map or a
* struct. This represents how Spark SQL generally serializes an object of
Expand All @@ -236,14 +184,15 @@ object ExpressionEncoder {
* representation. This represents how Spark SQL generally deserializes
* a serialized value in Spark SQL representation back to an object of
* type `T`.
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
encoder: AgnosticEncoder[T],
objSerializer: Expression,
objDeserializer: Expression,
clsTag: ClassTag[T])
objDeserializer: Expression)
extends Encoder[T] {

override def clsTag: ClassTag[T] = encoder.clsTag

/**
* A sequence of expressions, one for each top-level field that can be used to
* extract the values from a raw object into an [[InternalRow]]:
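A short sketch of the new tuple path (internal Catalyst API; the component encoders are illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Tupling now happens at the AgnosticEncoder level: ExpressionEncoder.tuple pulls the
// underlying AgnosticEncoder out of each input, combines them with ProductEncoder.tuple,
// and wraps the result back into an ExpressionEncoder, instead of rewriting the old
// serializer/deserializer expression trees by hand.
val intEnc = ExpressionEncoder[Int]()
val strEnc = ExpressionEncoder[String]()
val pairEnc: ExpressionEncoder[_] = ExpressionEncoder.tuple(Seq(intEnc, strEnc))
```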
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.SerializerSupport
/**
* A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array.
*/
class KryoSerializationCodecImpl extends Codec [Any, Array[Byte]] {
class KryoSerializationCodecImpl extends Codec[Any, Array[Byte]] {
private val serializer = SerializerSupport.newSerializer(useKryo = true)
override def encode(in: Any): Array[Byte] =
serializer.serialize(in).array()
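A hedged roundtrip through the codec, assuming the class is reachable from the caller's package (it is normally obtained via the `KryoSerializationCodec` provider rather than constructed by hand):

```scala
import org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl

// Encode an arbitrary object to Kryo bytes and decode it back.
val codec = new KryoSerializationCodecImpl
val bytes: Array[Byte] = codec.encode(Seq(1, 2, 3))
val restored = codec.decode(bytes) // returns Any; here a Seq(1, 2, 3)
```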
@@ -1307,10 +1307,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
messageParameters = Map("badRecord" -> badRecord))
}

def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
}

def expressionDecodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "EXPRESSION_DECODING_FAILED",
@@ -531,7 +531,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
val encoder = ExpressionEncoder(schema, lenient = true)
val unexpectedSerializer = NaNvl(encoder.objSerializer, encoder.objSerializer)
val exception = intercept[org.apache.spark.SparkRuntimeException] {
new ExpressionEncoder[Row](unexpectedSerializer, encoder.objDeserializer, encoder.clsTag)
new ExpressionEncoder[Row](encoder.encoder, unexpectedSerializer, encoder.objDeserializer)
}
checkError(
exception = exception,
12 changes: 8 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1187,11 +1187,15 @@ class DatasetSuite extends QueryTest
exception = intercept[AnalysisException] {
df.as[KryoData]
},
condition = "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
condition = "CANNOT_UP_CAST_DATATYPE",
parameters = Map(
"sqlExpr" -> "\"a\"",
"srcType" -> "\"DOUBLE\"",
"targetType" -> "\"BINARY\""))
"expression" -> "a",
"sourceType" -> "\"DOUBLE\"",
"targetType" -> "\"BINARY\"",
"details" -> ("The type path of the target object is:\n- root class: " +
"\"org.apache.spark.sql.KryoData\"\n" +
"You can either add an explicit cast to the input data or choose a " +
"higher precision type of the field in the target object")))
}

test("Java encoder") {
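For reference, a hypothetical standalone reproduction of the scenario this updated test covers (`KryoData` here is a stand-in for the suite's test class):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

// Hypothetical stand-in for the suite's KryoData class.
class KryoData(val a: Int) extends Serializable

val spark = SparkSession.builder().master("local[1]").getOrCreate()
implicit val kryoEncoder: Encoder[KryoData] = Encoders.kryo[KryoData]

// KryoData is stored as BINARY, so asking for Dataset[KryoData] on a DOUBLE column
// now fails analysis with CANNOT_UP_CAST_DATATYPE instead of a generic cast mismatch.
val df = spark.range(1).selectExpr("cast(id as double) as a")
df.as[KryoData] // throws AnalysisException (CANNOT_UP_CAST_DATATYPE)
```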
@@ -1176,7 +1176,7 @@ class QueryExecutionErrorsSuite
val enc: ExpressionEncoder[Row] = ExpressionEncoder(rowEnc)
val deserializer = AttributeReference.apply("v", IntegerType)()
implicit val im: ExpressionEncoder[Row] = new ExpressionEncoder[Row](
enc.objSerializer, deserializer, enc.clsTag)
rowEnc, enc.objSerializer, deserializer)

checkError(
exception = intercept[SparkRuntimeException] {
