Skip to content

Commit

Permalink
Add support for FiniteDuration (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
joan38 authored Oct 26, 2023
1 parent 2c88a28 commit 81936fb
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
package scala3encoders.derivation

import scala.compiletime.{constValue, summonInline, erasedValue}
import scala.compiletime.{constValue, erasedValue, summonInline}
import scala.deriving.Mirror
import scala.reflect.{ClassTag, Enum}

import org.apache.spark.sql.catalyst.expressions.{
Expression,
If,
IsNull,
Literal
}
import org.apache.spark.sql.catalyst.expressions.{Expression, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.*
import org.apache.spark.sql.catalyst.WalkedTypePath
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.expressions.objects.*
import org.apache.spark.sql.helper.Helper
import org.apache.spark.sql.types.*
import scala.concurrent.duration.FiniteDuration
import scala.jdk.javaapi.DurationConverters

trait Deserializer[T]:
def inputType: DataType
Expand Down Expand Up @@ -116,6 +112,18 @@ object Deserializer:
def deserialize(path: Expression): Expression =
createDeserializerForDuration(path)

given Deserializer[FiniteDuration] with
  // FiniteDuration is carried on the wire as a Spark day-time interval.
  def inputType: DataType = DayTimeIntervalType()

  // Delegate to the java.time.Duration deserializer, then bridge the result
  // into a scala FiniteDuration with DurationConverters.toScala.
  def deserialize(path: Expression): Expression =
    StaticInvoke(
      DurationConverters.getClass,
      ObjectType(classOf[FiniteDuration]),
      "toScala",
      List(summon[Deserializer[java.time.Duration]].deserialize(path)),
      returnNullable = false
    )

given Deserializer[java.time.Period] with
def inputType: DataType = YearMonthIntervalType()
def deserialize(path: Expression): Expression =
Expand Down
22 changes: 17 additions & 5 deletions encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package scala3encoders.derivation

import scala.compiletime.{constValue, summonInline, erasedValue}
import scala.compiletime.{constValue, erasedValue, summonInline}
import scala.deriving.Mirror
import scala.reflect.{ClassTag, Enum}

import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNotNull}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke, UnwrapOption}
import org.apache.spark.sql.catalyst.SerializerBuildHelper.*
import org.apache.spark.sql.helper.Helper
import org.apache.spark.sql.types.*
import org.apache.spark.sql.catalyst.expressions.objects.UnwrapOption
import scala.concurrent.duration.FiniteDuration
import scala.jdk.javaapi.DurationConverters

trait Serializer[T]:
def inputType: DataType
Expand Down Expand Up @@ -94,6 +94,18 @@ object Serializer:
def serialize(inputObject: Expression): Expression =
createSerializerForJavaDuration(inputObject)

given Serializer[FiniteDuration] with
  def inputType: DataType = ObjectType(classOf[FiniteDuration])

  // Bridge the scala FiniteDuration into a java.time.Duration via
  // DurationConverters.toJava, then delegate to the existing
  // java.time.Duration serializer.
  def serialize(inputObject: Expression): Expression =
    val asJava = StaticInvoke(
      DurationConverters.getClass,
      ObjectType(classOf[java.time.Duration]),
      "toJava",
      List(inputObject),
      returnNullable = false
    )
    summon[Serializer[java.time.Duration]].serialize(asJava)

given Serializer[java.time.Period] with
def inputType: DataType = ObjectType(classOf[java.time.Period])
def serialize(inputObject: Expression): Expression =
Expand All @@ -118,7 +130,7 @@ object Serializer:
def inputType: DataType = ObjectType(classOf[String])
def serialize(inputObject: Expression): Expression =
createSerializerForString(inputObject)

given [E <: Enum: ClassTag]: Serializer[E] with
def inputType: DataType = ObjectType(summon[ClassTag[E]].runtimeClass)
def serialize(inputObject: Expression): Expression =
Expand Down
42 changes: 27 additions & 15 deletions encoders/src/test/scala/sql/EncoderDerivationSpec.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package scala3encoders

import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.*
import org.apache.spark.sql.functions.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import java.io.{File, PrintWriter}
import scala.concurrent.duration.*

case class A()
case class B(x: String)
Expand Down Expand Up @@ -133,6 +134,7 @@ case class Sequence(id: Int, nums: Seq[Int])
case class City(name: String, lat: Double, lon: Double)
case class CityWithInts(name: String, lat: Int, lon: Int)
case class Journey(id: Int, cities: Seq[City])
case class DurationData(duration: FiniteDuration)

enum Color:
case Red, Black
Expand All @@ -152,7 +154,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
assertEquals(encoder.schema, StructType(Seq.empty))

val input = Seq(A(), A())
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("derive encoder of case class B(x: String)") {
Expand All @@ -163,7 +165,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
)

val input = Seq(B("hello"), B("world"))
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("derive encoder of case class C(x: Int, y: Long)") {
Expand All @@ -180,7 +182,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:

val input =
Seq(C(42, -9_223_372_036_854_775_808L), C(0, 9_223_372_036_854_775_807L))
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("derive encoder of case class Pos and collect as tuple") {
Expand Down Expand Up @@ -213,7 +215,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
)

val input = Seq(D("Hello", B("World")), D("Bye", B("Universe")))
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test(
Expand All @@ -225,7 +227,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
E(null, Array(1, 2), Set(1.0, 2.0)),
E(Map(), null, null)
)
val res = input.toDS.collect.toSeq
val res = input.toDS().collect.toSeq
assertEquals(res.map(_.x), input.map(_.x))
assert(
res
Expand All @@ -245,7 +247,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
F(None, Some(1L), (-1, 0, 1)),
F(null, null, (0, 0, 0))
)
val res = input.toDS.collect.toSeq
val res = input.toDS().collect.toSeq
assertEquals(res(0), input(0))
assertEquals(res(1), input(1))
// null will be mapped to None
Expand All @@ -268,15 +270,15 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
2
)
)
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("derive encoder of case class A()") {
val encoder = summon[Encoder[A]]
assertEquals(encoder.schema, StructType(Seq.empty))

val input = Seq(A(), A())
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("derive encoder of FiniteDuration") {
Expand All @@ -291,7 +293,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:

test("List[Int]") {
val ls = List(List(1, 2, 3), List(4, 5, 6))
assertEquals(ls.toDS.collect().toList, ls)
assertEquals(ls.toDS().collect().toList, ls)
}

test("List[case class]") {
Expand All @@ -302,7 +304,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
val seqs = seq1 :: seq2 :: seq3 :: Nil

assertEquals(
seqs.toDS.collect().toList,
seqs.toDS().collect().toList,
seqs
)
}
Expand All @@ -319,7 +321,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:

val trips = trip1 :: trip2 :: trip3 :: Nil

val idsIncrement = trips.toDS.map(tr => tr.copy(id = tr.id + 1))
val idsIncrement = trips.toDS().map(tr => tr.copy(id = tr.id + 1))

assertEquals(
idsIncrement.collect().toList,
Expand Down Expand Up @@ -358,7 +360,7 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
Map(Key("foo", java.time.LocalDate.now().minusDays(10)) -> 123L)
)
)
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

test("check Big class") {
Expand All @@ -371,7 +373,17 @@ class EncoderDerivationSpec extends munit.FunSuite with SparkSqlTesting:
74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
92, 93, 94, 95, 96, 97, 98, 99)
)
assertEquals(input.toDS.collect.toSeq, input)
assertEquals(input.toDS().collect.toSeq, input)
}

// Renamed from "derive encoder of FiniteDuration": a test with that exact
// name already exists earlier in this suite, and duplicate test names make
// failure reports ambiguous (and are rejected outright by some frameworks).
test("derive encoder of case class DurationData(duration: FiniteDuration)") {
  val data = Seq(DurationData(1.minute), DurationData(2.seconds))
    .toDS()
    .map(row => row.copy(duration = row.duration * 2))

  // FiniteDuration maps to a Spark day-time interval (DAY TO SECOND,
  // i.e. startField = 0, endField = 3).
  assertEquals(
    data.schema,
    StructType(
      Seq(StructField("duration", DayTimeIntervalType(startField = 0, endField = 3), true))
    )
  )
  // The doubling applied in the map above must survive the encode/decode
  // round trip through the interval representation.
  assertEquals(
    data.collect().toSeq,
    Seq(DurationData(2.minutes), DurationData(4.seconds))
  )
}

if (spark.version.split("\\.")(1).toInt > 3) then
Expand Down

0 comments on commit 81936fb

Please sign in to comment.