diff --git a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala
index 3f2a9a5..36049d2 100644
--- a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala
+++ b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala
@@ -1,20 +1,16 @@
 package scala3encoders.derivation
 
-import scala.compiletime.{constValue, summonInline, erasedValue}
+import scala.compiletime.{constValue, erasedValue, summonInline}
 import scala.deriving.Mirror
 import scala.reflect.{ClassTag, Enum}
-
-import org.apache.spark.sql.catalyst.expressions.{
-  Expression,
-  If,
-  IsNull,
-  Literal
-}
+import org.apache.spark.sql.catalyst.expressions.{Expression, If, IsNull, Literal}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.*
 import org.apache.spark.sql.catalyst.WalkedTypePath
-import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.expressions.objects.*
 import org.apache.spark.sql.helper.Helper
 import org.apache.spark.sql.types.*
+import scala.concurrent.duration.FiniteDuration
+import scala.jdk.javaapi.DurationConverters
 trait Deserializer[T]:
   def inputType: DataType
   def deserialize(path: Expression): Expression
@@ -116,6 +112,19 @@ object Deserializer:
     def deserialize(path: Expression): Expression =
       createDeserializerForDuration(path)
 
+  given Deserializer[FiniteDuration] with
+    def inputType: DataType = DayTimeIntervalType()
+    def deserialize(path: Expression): Expression =
+      // Read the value as java.time.Duration, then convert it via DurationConverters.toScala
+      val javaDuration = summon[Deserializer[java.time.Duration]].deserialize(path)
+      StaticInvoke(
+        DurationConverters.getClass,
+        ObjectType(classOf[FiniteDuration]),
+        "toScala",
+        javaDuration :: Nil,
+        returnNullable = false
+      )
+
   given Deserializer[java.time.Period] with
     def inputType: DataType = YearMonthIntervalType()
     def deserialize(path: Expression): Expression =
diff --git a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
index 47196de..d671f42 100644
--- a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
+++ b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
@@ -1,15 +1,15 @@
 package scala3encoders.derivation
 
-import scala.compiletime.{constValue, summonInline, erasedValue}
+import scala.compiletime.{constValue, erasedValue, summonInline}
 import scala.deriving.Mirror
 import scala.reflect.{ClassTag, Enum}
-
 import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNotNull}
-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke, UnwrapOption}
 import org.apache.spark.sql.catalyst.SerializerBuildHelper.*
 import org.apache.spark.sql.helper.Helper
 import org.apache.spark.sql.types.*
-import org.apache.spark.sql.catalyst.expressions.objects.UnwrapOption
+import scala.concurrent.duration.FiniteDuration
+import scala.jdk.javaapi.DurationConverters
 
 trait Serializer[T]:
   def inputType: DataType
@@ -94,6 +94,19 @@ object Serializer:
     def serialize(inputObject: Expression): Expression =
       createSerializerForJavaDuration(inputObject)
 
+  given Serializer[FiniteDuration] with
+    def inputType: DataType = ObjectType(classOf[FiniteDuration])
+    def serialize(inputObject: Expression): Expression =
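+      // Bridge to java.time.Duration via DurationConverters.toJava, then delegate
+      // to the existing java.time.Duration serializer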
+      val javaDuration = StaticInvoke(
+        DurationConverters.getClass,
+        ObjectType(classOf[java.time.Duration]),
+        "toJava",
+        inputObject :: Nil,
+        returnNullable = false
+      )
+      summon[Serializer[java.time.Duration]].serialize(javaDuration)
+
   given Serializer[java.time.Period] with
     def inputType: DataType = ObjectType(classOf[java.time.Period])
     def serialize(inputObject: Expression): Expression =
@@ -118,7 +132,7 @@ object Serializer:
     def inputType: DataType = ObjectType(classOf[String])
     def serialize(inputObject: Expression): Expression =
       createSerializerForString(inputObject)
-    
+
   given [E <: Enum: ClassTag]: Serializer[E] with
     def inputType: DataType = ObjectType(summon[ClassTag[E]].runtimeClass)
     def serialize(inputObject: Expression): Expression =
diff --git a/encoders/src/test/scala/sql/EncoderDerivationSpec.scala b/encoders/src/test/scala/sql/EncoderDerivationSpec.scala
index 806251f..b192d36 100644
--- a/encoders/src/test/scala/sql/EncoderDerivationSpec.scala
+++ b/encoders/src/test/scala/sql/EncoderDerivationSpec.scala
@@ -1,10 +1,11 @@
 package scala3encoders
 
 import org.apache.spark.sql.{AnalysisException, Encoder}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.*
+import org.apache.spark.sql.functions.*
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import java.io.{File, PrintWriter}
+import scala.concurrent.duration.*
 
 case class A()
 case class B(x: String)
@@ -133,6 +134,7 @@ case class Sequence(id: Int, nums: Seq[Int])
 case class City(name: String, lat: Double, lon: Double)
 case class CityWithInts(name: String, lat: Int, lon: Int)
 case class Journey(id: Int, cities: Seq[City])
+case class DurationData(duration: FiniteDuration)
 
 enum Color:
   case Red, Black
@@ -152,7 +154,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
     assertEquals(encoder.schema, StructType(Seq.empty))
 
     val input = Seq(A(), A())
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test("derive encoder of case class B(x: String)") {
@@ -163,7 +165,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
     )
 
     val input = Seq(B("hello"), B("world"))
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test("derive encoder of case class C(x: Int, y: Long)") {
@@ -180,7 +182,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
 
     val input =
       Seq(C(42, -9_223_372_036_854_775_808L), C(0, 9_223_372_036_854_775_807L))
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test("derive encoder of case class Pos and collect as tuple") {
@@ -213,7 +215,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
     )
 
     val input = Seq(D("Hello", B("World")), D("Bye", B("Universe")))
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test(
@@ -225,7 +227,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
       E(null, Array(1, 2), Set(1.0, 2.0)),
       E(Map(), null, null)
     )
-    val res = input.toDS.collect.toSeq
+    val res = input.toDS().collect.toSeq
     assertEquals(res.map(_.x), input.map(_.x))
     assert(
       res
@@ -245,7 +247,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
       F(None, Some(1L), (-1, 0, 1)),
       F(null, null, (0, 0, 0))
     )
-    val res = input.toDS.collect.toSeq
+    val res = input.toDS().collect.toSeq
    assertEquals(res(0), input(0))
     assertEquals(res(1), input(1))
     // null will be mapped to None
@@ -268,7 +270,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
         2
       )
     )
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
  }
 
   test("derive encoder of case class A()") {
@@ -276,7 +278,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
     assertEquals(encoder.schema, StructType(Seq.empty))
 
     val input = Seq(A(), A())
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test("derive encoder of FiniteDuration") {
@@ -291,7 +293,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
 
   test("List[Int]") {
     val ls = List(List(1, 2, 3), List(4, 5, 6))
-    assertEquals(ls.toDS.collect().toList, ls)
+    assertEquals(ls.toDS().collect().toList, ls)
   }
 
   test("List[case class]") {
@@ -302,7 +304,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
     val seqs = seq1 :: seq2 :: seq3 :: Nil
 
     assertEquals(
-      seqs.toDS.collect().toList,
+      seqs.toDS().collect().toList,
       seqs
     )
   }
@@ -319,7 +321,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
 
     val trips = trip1 :: trip2 :: trip3 :: Nil
 
-    val idsIncrement = trips.toDS.map(tr => tr.copy(id = tr.id + 1))
+    val idsIncrement = trips.toDS().map(tr => tr.copy(id = tr.id + 1))
 
     assertEquals(
       idsIncrement.collect().toList,
@@ -358,7 +360,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
         Map(Key("foo", java.time.LocalDate.now().minusDays(10)) -> 123L)
       )
     )
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
   }
 
   test("check Big class") {
@@ -371,7 +373,18 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
         74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
         92, 93, 94, 95, 96, 97, 98, 99)
     )
-    assertEquals(input.toDS.collect.toSeq, input)
+    assertEquals(input.toDS().collect.toSeq, input)
+  }
+
+  test("derive encoder of case class DurationData(duration: FiniteDuration)") {
+    val data = Seq(DurationData(1.minute), DurationData(2.seconds)).toDS()
+      .map(row => row.copy(duration = row.duration * 2))
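+    // FiniteDuration is encoded like java.time.Duration: a day-time interval from DAY (0) to SECOND (3)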
+    assertEquals(
+      data.schema,
+      StructType(Seq(StructField("duration", DayTimeIntervalType(startField = 0, endField = 3), true)))
+    )
+    assertEquals(data.collect().toSeq, Seq(DurationData(2.minutes), DurationData(4.seconds)))
   }
 
   if (spark.version.split("\\.")(1).toInt > 3) then