From 2c88a28f7a5c3c769963e3aa997490d5aba0d41b Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Thu, 26 Oct 2023 02:29:23 -0400 Subject: [PATCH] Add support for enums (#52) --- .../scala/scala3encoders/EncoderDerivation.scala | 2 +- .../scala3encoders/derivation/Deserializer.scala | 15 ++++++++++++++- .../scala3encoders/derivation/Serializer.scala | 13 ++++++++++++- .../test/scala/sql/EncoderDerivationSpec.scala | 14 ++++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala b/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala index 32240a3..3cb5ed1 100644 --- a/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala +++ b/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala @@ -13,7 +13,7 @@ given encoder[T](using deserializer: Deserializer[T], classTag: ClassTag[T] ): ExpressionEncoder[T] = - val inputObject = BoundReference(0, serializer.inputType, true) + val inputObject = BoundReference(0, serializer.inputType, nullable = true) val path = GetColumnByOrdinal(0, deserializer.inputType) ExpressionEncoder( diff --git a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala index 3a53461..3f2a9a5 100644 --- a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala +++ b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala @@ -2,7 +2,7 @@ package scala3encoders.derivation import scala.compiletime.{constValue, summonInline, erasedValue} import scala.deriving.Mirror -import scala.reflect.ClassTag +import scala.reflect.{ClassTag, Enum} import org.apache.spark.sql.catalyst.expressions.{ Expression, @@ -144,6 +144,19 @@ object Deserializer: def deserialize(path: Expression): Expression = createDeserializerForScalaBigInt(path) + given[E <: Enum : ClassTag]: Deserializer[E] with + def inputType: DataType = StringType + + def deserialize(path: 
Expression): Expression = + val string = summon[Deserializer[String]].deserialize(path) + StaticInvoke( + summon[ClassTag[E]].runtimeClass, + ObjectType(summon[ClassTag[E]].runtimeClass), + "valueOf", + string :: Nil, + returnNullable = false + ) + inline given deriveArray[T](using d: Deserializer[T], ct: ClassTag[T] diff --git a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala index 14bcb4c..47196de 100644 --- a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala +++ b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala @@ -2,7 +2,7 @@ package scala3encoders.derivation import scala.compiletime.{constValue, summonInline, erasedValue} import scala.deriving.Mirror -import scala.reflect.ClassTag +import scala.reflect.{ClassTag, Enum} import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNotNull} import org.apache.spark.sql.catalyst.expressions.objects.Invoke @@ -118,6 +118,17 @@ object Serializer: def inputType: DataType = ObjectType(classOf[String]) def serialize(inputObject: Expression): Expression = createSerializerForString(inputObject) + + given [E <: Enum: ClassTag]: Serializer[E] with + def inputType: DataType = ObjectType(summon[ClassTag[E]].runtimeClass) + def serialize(inputObject: Expression): Expression = + val string = Invoke( + inputObject, + "toString", + ObjectType(classOf[String]), + returnNullable = false + ) + summon[Serializer[String]].serialize(string) given deriveSeq[F[_], T](using s: Serializer[T])(using F[T] <:< Seq[T] diff --git a/encoders/src/test/scala/sql/EncoderDerivationSpec.scala b/encoders/src/test/scala/sql/EncoderDerivationSpec.scala index 840d408..806251f 100644 --- a/encoders/src/test/scala/sql/EncoderDerivationSpec.scala +++ b/encoders/src/test/scala/sql/EncoderDerivationSpec.scala @@ -134,6 +134,10 @@ case class City(name: String, lat: Double, lon: Double) case class CityWithInts(name: String, lat: Int, 
lon: Int) case class Journey(id: Int, cities: Seq[City]) +enum Color: + case Red, Black +case class ColorData(color: Color) + val dSchema = StructType( Seq( @@ -275,6 +279,16 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting: assertEquals(input.toDS.collect.toSeq, input) } + test("derive encoder of enum") { + val data = Seq(ColorData(Color.Black), ColorData(Color.Red)).toDS() + .map(_.copy(Color.Red)) + assertEquals( + data.schema, + StructType(Seq(StructField("color", StringType, true))) + ) + assertEquals(data.collect().toSeq, Seq(ColorData(Color.Red), ColorData(Color.Red))) + } + test("List[Int]") { val ls = List(List(1, 2, 3), List(4, 5, 6)) assertEquals(ls.toDS.collect().toList, ls)