Merge pull request #42 from michael72/main
upgrade to spark 3.5.0 + cleanup
vincenzobaz authored Oct 1, 2023
2 parents 4b56d12 + b5f8bf7 commit 48b2b50
Showing 10 changed files with 97 additions and 48 deletions.
7 changes: 5 additions & 2 deletions build.sbt
@@ -1,7 +1,10 @@
-ThisBuild / scalaVersion := "3.3.0"
+ThisBuild / scalaVersion := "3.3.1"
 ThisBuild / semanticdbEnabled := true
+ThisBuild / scalacOptions ++= List(
+  "-Wunused:imports"
+)

-val sparkVersion = "3.3.2"
+val sparkVersion = "3.5.0"
 val sparkSql = ("org.apache.spark" %% "spark-sql" % sparkVersion).cross(
   CrossVersion.for3Use2_13
 )
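The CrossVersion.for3Use2_13 setting kept above is what makes the whole build work: Spark is published for Scala 2.13 only, and for3Use2_13 tells sbt to resolve the _2.13 artifact from a Scala 3 project. A minimal downstream build.sbt following the same pattern might look like this (a sketch; versions mirror the diff, adjust to your build):

// build.sbt (sketch): use the Scala 2.13 Spark artifacts from Scala 3
ThisBuild / scalaVersion := "3.3.1"

libraryDependencies ++= Seq(
  // for3Use2_13 resolves spark-sql_2.13 even though scalaVersion is 3.x
  ("org.apache.spark" %% "spark-sql" % "3.5.0").cross(CrossVersion.for3Use2_13)
)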
@@ -3,7 +3,6 @@ package scala3encoders
import scala3encoders.derivation.{Deserializer, Serializer}
import scala.reflect.ClassTag

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
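These imports belong to the encoder derivation itself: judging by them, the given wires a Serializer and a Deserializer into an ExpressionEncoder roughly as below. This is a sketch inferred from the imports, not code shown in this diff; parameter names and nullability are assumptions.

// Sketch (assumption based on the imports above):
given encoder[T](using s: Serializer[T], d: Deserializer[T], ct: ClassTag[T]): Encoder[T] =
  ExpressionEncoder(
    // serialize column 0 of the input row
    objSerializer = s.serialize(BoundReference(0, s.inputType, nullable = true)),
    // deserialize by resolving column 0 by ordinal
    objDeserializer = d.deserialize(GetColumnByOrdinal(0, d.inputType)),
    clsTag = ct
  )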
encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala
@@ -4,21 +4,15 @@ import scala.compiletime.{constValue, summonInline, erasedValue}
 import scala.deriving.Mirror
 import scala.reflect.ClassTag

-import org.apache.spark.sql.catalyst.expressions.{
-  Expression,
-  If,
-  IsNull,
-  Literal
-}
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.*
+import org.apache.spark.sql.catalyst.WalkedTypePath
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.expressions.GetStructField
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.helper.Helper

 import org.apache.spark.sql.types.*
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.WalkedTypePath
-import org.apache.spark.sql.catalyst.expressions.GetStructField

 trait Deserializer[T]:
   def inputType: DataType
@@ -36,7 +30,7 @@ object Deserializer:
     override def inputType: DataType = d.inputType

     override def deserialize(path: Expression): Expression =
-      val tpe = ScalaReflection.typeBoxedJavaMapping.getOrElse(
+      val tpe = Helper.typeBoxedJavaMapping.getOrElse(
         d.inputType,
         ct.runtimeClass
       )
@@ -125,14 +119,6 @@
     def deserialize(path: Expression): Expression =
       createDeserializerForPeriod(path)

-  /*given deriveEnum[T](using d: Deserializer[T], ct: ClassTag[T]): Deserializer[java.lang.Enum[T]] with
-    def inputType: DataType = StringType
-    def deserialize(path: Expression): Expression =
-      createDeserializerForTypesSupportValueOf(
-        Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
-        // TODO !!
-        ct.getClass())*/
-
   given Deserializer[String] with
     def inputType: DataType = StringType
     def deserialize(path: Expression): Expression =
@@ -165,12 +151,12 @@
     override def inputType: DataType = ArrayType(d.inputType)
     override def deserialize(path: Expression): Expression =
       val mapFunction: Expression => Expression = el =>
-        deserializerForWithNullSafetyAndUpcast(
+        Helper.deserializerForWithNullSafetyAndUpcast(
           el,
           d.inputType,
           true,
           WalkedTypePath(Nil),
-          (casted, _) => d.deserialize(casted)
+          d.deserialize
         )
       val arrayClass = ObjectType(ct.newArray(0).getClass)
       val arrayData = UnresolvedMapObjects(mapFunction, path)
@@ -196,12 +182,12 @@
     override def inputType: DataType = ArrayType(d.inputType)
     override def deserialize(path: Expression): Expression =
       val mapFunction: Expression => Expression = element =>
-        deserializerForWithNullSafetyAndUpcast(
+        Helper.deserializerForWithNullSafetyAndUpcast(
           element,
           d.inputType,
           nullable = true,
           WalkedTypePath(Nil),
-          (casted, _) => d.deserialize(casted)
+          d.deserialize
         )
       UnresolvedMapObjects(mapFunction, path, Some(classOf[Seq[T]]))

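Both collection deserializers above now delegate the per-element null-safe upcast to the copied Helper (next file) instead of Spark's ScalaReflection, and since the copied callback takes only an Expression, d.deserialize can be passed directly in place of (casted, _) => d.deserialize(casted). For users the derivation is unchanged; a minimal usage sketch, assuming a local session (case class and values are illustrative):

import org.apache.spark.sql.SparkSession
import scala3encoders.given // derives the Encoder, including the Seq deserializer above

case class Point(label: String, xs: Seq[Double])

@main def pointDemo(): Unit =
  val spark = SparkSession.builder().master("local").getOrCreate()
  try
    // Seq[Double] round-trips through ArrayType(DoubleType) via the derived encoder
    spark.createDataset(Seq(Point("a", Seq(1.0, 2.0)), Point("b", Seq(3.0)))).show()
  finally spark.stop()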
75 changes: 75 additions & 0 deletions encoders/src/main/scala/scala3encoders/derivation/Helper.scala
@@ -0,0 +1,75 @@
+package org.apache.spark.sql.helper
+
+import org.apache.spark.sql.catalyst.expressions.{
+  CheckOverflow,
+  Expression,
+  UpCast
+}
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
+import org.apache.spark.sql.catalyst.WalkedTypePath
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+// This is copied from spark to support older versions of Spark and 3.5.0 -
+// it was part of ScalaReflection and was moved to EncoderUtils in 3.5.0
+object Helper {
+  private val nullOnOverflow = !SQLConf.get.ansiEnabled
+
+  val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
+    BooleanType -> classOf[java.lang.Boolean],
+    ByteType -> classOf[java.lang.Byte],
+    ShortType -> classOf[java.lang.Short],
+    IntegerType -> classOf[java.lang.Integer],
+    LongType -> classOf[java.lang.Long],
+    FloatType -> classOf[java.lang.Float],
+    DoubleType -> classOf[java.lang.Double],
+    DateType -> classOf[java.lang.Integer],
+    TimestampType -> classOf[java.lang.Long],
+    TimestampNTZType -> classOf[java.lang.Long]
+  )
+
+  def createSerializerForBigInteger(inputObject: Expression): Expression = {
+    CheckOverflow(
+      StaticInvoke(
+        Decimal.getClass,
+        DecimalType.BigIntDecimal,
+        "apply",
+        inputObject :: Nil,
+        returnNullable = false
+      ),
+      DecimalType.BigIntDecimal,
+      nullOnOverflow
+    )
+  }
+
+  private def upCastToExpectedType(
+      expr: Expression,
+      expected: DataType,
+      walkedTypePath: WalkedTypePath
+  ): Expression = expected match {
+    case _: StructType => expr
+    case _: ArrayType => expr
+    case _: MapType => expr
+    case _: DecimalType =>
+      // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale.
+      // Here we use the `DecimalType` object to indicate it.
+      UpCast(expr, DecimalType, walkedTypePath.getPaths)
+    case _ => UpCast(expr, expected, walkedTypePath.getPaths)
+  }
+
+  def deserializerForWithNullSafetyAndUpcast(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: WalkedTypePath,
+      funcForCreatingDeserializer: Expression => Expression
+  ): Expression = {
+    val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
+    expressionWithNullSafety(
+      funcForCreatingDeserializer(casted),
+      nullable,
+      walkedTypePath
+    )
+  }
+}
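Two details of the copy are easy to miss: nullOnOverflow is read once from spark.sql.ansi.enabled when Helper first initializes, and typeBoxedJavaMapping is the table the ClassTag-based deserializer earlier in this diff consults before falling back to the runtime class. An illustrative REPL-style check (names from this diff):

import org.apache.spark.sql.helper.Helper
import org.apache.spark.sql.types.{IntegerType, StringType}

Helper.typeBoxedJavaMapping(IntegerType)
// classOf[java.lang.Integer]: primitive-backed Catalyst types map to their boxed classes

Helper.typeBoxedJavaMapping.getOrElse(StringType, classOf[String])
// classOf[String]: not in the map, so callers fall back (Deserializer.scala uses ct.runtimeClass)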
14 changes: 4 additions & 10 deletions encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
@@ -7,9 +7,9 @@ import scala.reflect.ClassTag
 import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNotNull}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.SerializerBuildHelper.*
+import org.apache.spark.sql.helper.Helper
 import org.apache.spark.sql.types.*
 import org.apache.spark.sql.catalyst.expressions.objects.UnwrapOption
-import org.apache.spark.sql.catalyst.ScalaReflection

 trait Serializer[T]:
   def inputType: DataType
@@ -102,23 +102,17 @@ object Serializer:
   given Serializer[BigDecimal] with
     def inputType: DataType = ObjectType(classOf[BigDecimal])
     def serialize(inputObject: Expression): Expression =
-      createSerializerForScalaBigDecimal(inputObject)
+      Helper.createSerializerForBigInteger(inputObject)

   given Serializer[java.math.BigInteger] with
     def inputType: DataType = ObjectType(classOf[java.math.BigInteger])
     def serialize(inputObject: Expression): Expression =
-      createSerializerForJavaBigInteger(inputObject)
+      Helper.createSerializerForBigInteger(inputObject)

   given Serializer[scala.math.BigInt] with
     def inputType: DataType = ObjectType(classOf[scala.math.BigInt])
     def serialize(inputObject: Expression): Expression =
-      createSerializerForScalaBigInt(inputObject)
-
-    // TODO
-    /*given Serializer[Enum[_]] with
-    def inputType: DataType = ObjectType(classOf[Enum[_]])
-    def serialize(inputObject: Expression): Expression =
-      createSerializerForJavaEnum(inputObject)*/
+      Helper.createSerializerForBigInteger(inputObject)

   given Serializer[String] with
     def inputType: DataType = ObjectType(classOf[String])
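All three big-number serializers now funnel through Helper.createSerializerForBigInteger, so values are written as DecimalType.BigIntDecimal, and overflow either becomes null (ANSI mode off, the default) or an error (ANSI mode on). A hedged end-to-end sketch (case class and figures are illustrative):

import org.apache.spark.sql.SparkSession
import scala3encoders.given

case class Account(id: Long, balance: BigInt)

@main def bigIntDemo(): Unit =
  val spark = SparkSession.builder().master("local").getOrCreate()
  try
    val ds = spark.createDataset(Seq(Account(1L, BigInt("12345678901234567890"))))
    ds.printSchema() // balance surfaces as a decimal column via DecimalType.BigIntDecimal
    ds.show()
  finally spark.stop()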
2 changes: 1 addition & 1 deletion examples/src/main/scala/rdd/WordCountSql.scala
@@ -9,7 +9,7 @@ import scala3encoders.given
 @main def wordcountSql =
   val spark = SparkSession.builder().master("local").getOrCreate

-  import spark.implicits.{StringToColumn, rddToDatasetHolder}
+  import spark.implicits.rddToDatasetHolder

   try
     val sc = spark.sparkContext
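StringToColumn only provides the $-interpolator for column references, which this example no longer uses, so the cleanup keeps just the RDD conversion. For reference, a sketch of what each implicit enables (illustrative, assumes the spark session and scala3encoders.given from the file above):

import spark.implicits.{StringToColumn, rddToDatasetHolder}

val words = spark.sparkContext.parallelize(Seq("a", "b", "a"))
val counts = words.toDS()      // rddToDatasetHolder: RDD[String] => Dataset[String]
  .groupBy($"value")           // StringToColumn: $"value" column syntax
  .count()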
6 changes: 1 addition & 5 deletions examples/src/main/scala/sql/StarWars.scala
@@ -1,11 +1,7 @@
 package sql

-import org.apache.spark.sql.SparkSession
-
-import org.apache.spark.sql.{Dataset, DataFrame, SparkSession}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql._
 import buildinfo.BuildInfo.inputDirectory
+import org.apache.spark.sql.{Dataset, Encoder, SparkSession}

 object StarWars extends App:
   val spark = SparkSession.builder().master("local").getOrCreate
2 changes: 1 addition & 1 deletion project/build.properties
@@ -1 +1 @@
-sbt.version=1.9.3
+sbt.version=1.9.6
3 changes: 0 additions & 3 deletions udf/src/main/scala/scala3udf/Udf.scala
@@ -1,14 +1,11 @@
 package scala3udf

-import scala.reflect.ClassTag

 import org.apache.spark.sql.{Column, SparkSession}
 import org.apache.spark.sql.expressions.{Exporter, UserDefinedFunction}
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

 import scala.compiletime.{summonInline, erasedValue}
 import scala.deriving.Mirror
 import scala.quoted.*

 import scala3encoders.derivation.Deserializer
1 change: 0 additions & 1 deletion udf/src/test/scala/scala3udf/UdfSpec.scala
@@ -9,7 +9,6 @@ import scala3udf.{
 import scala3encoders.given

 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.encoders.RowEncoder

 case class DataWithPos(name: String, x: Int, y: Int, z: Int)
 case class DataWithX(name: String, x: Int)