diff --git a/build.sbt b/build.sbt index 3400ca2..4c55e8e 100644 --- a/build.sbt +++ b/build.sbt @@ -1,7 +1,10 @@ -ThisBuild / scalaVersion := "3.3.0" +ThisBuild / scalaVersion := "3.3.1" ThisBuild / semanticdbEnabled := true +ThisBuild / scalacOptions ++= List( + "-Wunused:imports" +) -val sparkVersion = "3.3.2" +val sparkVersion = "3.5.0" val sparkSql = ("org.apache.spark" %% "spark-sql" % sparkVersion).cross( CrossVersion.for3Use2_13 ) diff --git a/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala b/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala index 3534825..32240a3 100644 --- a/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala +++ b/encoders/src/main/scala/scala3encoders/EncoderDerivation.scala @@ -3,7 +3,6 @@ package scala3encoders import scala3encoders.derivation.{Deserializer, Serializer} import scala.reflect.ClassTag -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal diff --git a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala index ba4f36e..e111e4f 100644 --- a/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala +++ b/encoders/src/main/scala/scala3encoders/derivation/Deserializer.scala @@ -4,21 +4,15 @@ import scala.compiletime.{constValue, summonInline, erasedValue} import scala.deriving.Mirror import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.{ - Expression, - If, - IsNull, - Literal -} -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.DeserializerBuildHelper.* +import org.apache.spark.sql.catalyst.WalkedTypePath +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.GetStructField import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.helper.Helper import org.apache.spark.sql.types.* -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.expressions.GetStructField trait Deserializer[T]: def inputType: DataType @@ -36,7 +30,7 @@ object Deserializer: override def inputType: DataType = d.inputType override def deserialize(path: Expression): Expression = - val tpe = ScalaReflection.typeBoxedJavaMapping.getOrElse( + val tpe = Helper.typeBoxedJavaMapping.getOrElse( d.inputType, ct.runtimeClass ) @@ -125,14 +119,6 @@ object Deserializer: def deserialize(path: Expression): Expression = createDeserializerForPeriod(path) - /*given deriveEnum[T](using d: Deserializer[T], ct: ClassTag[T]): Deserializer[java.lang.Enum[T]] with - def inputType: DataType = StringType - def deserialize(path: Expression): Expression = - createDeserializerForTypesSupportValueOf( - Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), - // TODO !! 
-        ct.getClass())*/
-
   given Deserializer[String] with
     def inputType: DataType = StringType
     def deserialize(path: Expression): Expression =
@@ -165,12 +151,12 @@ object Deserializer:
     override def inputType: DataType = ArrayType(d.inputType)
     override def deserialize(path: Expression): Expression =
       val mapFunction: Expression => Expression = el =>
-        deserializerForWithNullSafetyAndUpcast(
+        Helper.deserializerForWithNullSafetyAndUpcast(
           el,
           d.inputType,
           true,
           WalkedTypePath(Nil),
-          (casted, _) => d.deserialize(casted)
+          d.deserialize
         )
       val arrayClass = ObjectType(ct.newArray(0).getClass)
       val arrayData = UnresolvedMapObjects(mapFunction, path)
@@ -196,12 +182,12 @@ object Deserializer:
     override def inputType: DataType = ArrayType(d.inputType)
     override def deserialize(path: Expression): Expression =
       val mapFunction: Expression => Expression = element =>
-        deserializerForWithNullSafetyAndUpcast(
+        Helper.deserializerForWithNullSafetyAndUpcast(
           element,
           d.inputType,
           nullable = true,
           WalkedTypePath(Nil),
-          (casted, _) => d.deserialize(casted)
+          d.deserialize
         )
       UnresolvedMapObjects(mapFunction, path, Some(classOf[Seq[T]]))
 
diff --git a/encoders/src/main/scala/scala3encoders/derivation/Helper.scala b/encoders/src/main/scala/scala3encoders/derivation/Helper.scala
new file mode 100644
index 0000000..b4f4ec9
--- /dev/null
+++ b/encoders/src/main/scala/scala3encoders/derivation/Helper.scala
@@ -0,0 +1,79 @@
+package org.apache.spark.sql.helper
+
+import org.apache.spark.sql.catalyst.expressions.{
+  CheckOverflow,
+  Expression,
+  UpCast
+}
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
+import org.apache.spark.sql.catalyst.WalkedTypePath
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+// Copied from Spark so the same code compiles against both 3.5.0 and older
+// versions: these helpers were part of ScalaReflection and moved to
+// EncoderUtils in Spark 3.5.0.
+object Helper {
+  // A def (not a val) so the session's current ANSI setting is read on each
+  // call, matching the behaviour of the original Spark helpers.
+  private def nullOnOverflow: Boolean = !SQLConf.get.ansiEnabled
+
+  val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
+    BooleanType -> classOf[java.lang.Boolean],
+    ByteType -> classOf[java.lang.Byte],
+    ShortType -> classOf[java.lang.Short],
+    IntegerType -> classOf[java.lang.Integer],
+    LongType -> classOf[java.lang.Long],
+    FloatType -> classOf[java.lang.Float],
+    DoubleType -> classOf[java.lang.Double],
+    DateType -> classOf[java.lang.Integer],
+    TimestampType -> classOf[java.lang.Long],
+    TimestampNTZType -> classOf[java.lang.Long]
+  )
+
+  // Serializes integral big numbers (scala.math.BigInt / java.math.BigInteger)
+  // as DecimalType.BigIntDecimal, i.e. Decimal(38, 0).
+  def createSerializerForBigInteger(inputObject: Expression): Expression = {
+    CheckOverflow(
+      StaticInvoke(
+        Decimal.getClass,
+        DecimalType.BigIntDecimal,
+        "apply",
+        inputObject :: Nil,
+        returnNullable = false
+      ),
+      DecimalType.BigIntDecimal,
+      nullOnOverflow
+    )
+  }
+
+  private def upCastToExpectedType(
+      expr: Expression,
+      expected: DataType,
+      walkedTypePath: WalkedTypePath
+  ): Expression = expected match {
+    case _: StructType => expr
+    case _: ArrayType => expr
+    case _: MapType => expr
+    case _: DecimalType =>
+      // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale.
+      // Here we use the `DecimalType` object to indicate it.
+      UpCast(expr, DecimalType, walkedTypePath.getPaths)
+    case _ => UpCast(expr, expected, walkedTypePath.getPaths)
+  }
+
+  def deserializerForWithNullSafetyAndUpcast(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: WalkedTypePath,
+      funcForCreatingDeserializer: Expression => Expression
+  ): Expression = {
+    val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
+    expressionWithNullSafety(
+      funcForCreatingDeserializer(casted),
+      nullable,
+      walkedTypePath
+    )
+  }
+}
diff --git a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
index 585a6db..14bcb4c 100644
--- a/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
+++ b/encoders/src/main/scala/scala3encoders/derivation/Serializer.scala
@@ -7,9 +7,9 @@ import scala.reflect.ClassTag
 import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNotNull}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.SerializerBuildHelper.*
+import org.apache.spark.sql.helper.Helper
 import org.apache.spark.sql.types.*
 import org.apache.spark.sql.catalyst.expressions.objects.UnwrapOption
-import org.apache.spark.sql.catalyst.ScalaReflection
 
 trait Serializer[T]:
   def inputType: DataType
@@ -102,23 +102,20 @@ object Serializer:
   given Serializer[BigDecimal] with
     def inputType: DataType = ObjectType(classOf[BigDecimal])
     def serialize(inputObject: Expression): Expression =
+      // NOTE: BigDecimal keeps its dedicated serializer from SerializerBuildHelper.
+      // Helper.createSerializerForBigInteger targets DecimalType.BigIntDecimal
+      // (precision 38, scale 0) and would round the fractional part away.
       createSerializerForScalaBigDecimal(inputObject)
 
   given Serializer[java.math.BigInteger] with
     def inputType: DataType = ObjectType(classOf[java.math.BigInteger])
     def serialize(inputObject: Expression): Expression =
-      createSerializerForJavaBigInteger(inputObject)
+      Helper.createSerializerForBigInteger(inputObject)
 
   given Serializer[scala.math.BigInt] with
     def inputType: DataType = ObjectType(classOf[scala.math.BigInt])
     def serialize(inputObject: Expression): Expression =
-      createSerializerForScalaBigInt(inputObject)
-
-  // TODO
-  /*given Serializer[Enum[_]] with
-    def inputType: DataType = ObjectType(classOf[Enum[_]])
-    def serialize(inputObject: Expression): Expression =
-      createSerializerForJavaEnum(inputObject)*/
+      Helper.createSerializerForBigInteger(inputObject)
 
   given Serializer[String] with
     def inputType: DataType = ObjectType(classOf[String])
diff --git a/examples/src/main/scala/rdd/WordCountSql.scala b/examples/src/main/scala/rdd/WordCountSql.scala
index ccaa642..ab43525 100644
--- a/examples/src/main/scala/rdd/WordCountSql.scala
+++ b/examples/src/main/scala/rdd/WordCountSql.scala
@@ -9,7 +9,7 @@ import scala3encoders.given
 @main def wordcountSql =
   val spark = SparkSession.builder().master("local").getOrCreate
 
-  import spark.implicits.{StringToColumn, rddToDatasetHolder}
+  import spark.implicits.rddToDatasetHolder
 
   try
     val sc = spark.sparkContext
diff --git a/examples/src/main/scala/sql/StarWars.scala b/examples/src/main/scala/sql/StarWars.scala
index 9485058..56e1f54 100644
--- a/examples/src/main/scala/sql/StarWars.scala
+++ b/examples/src/main/scala/sql/StarWars.scala
@@ -1,11 +1,7 @@
 package sql
 
-import org.apache.spark.sql.SparkSession
-
-import org.apache.spark.sql.{Dataset, DataFrame, SparkSession}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql._
 import buildinfo.BuildInfo.inputDirectory
+import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
 
 object StarWars extends App:
   val spark =
SparkSession.builder().master("local").getOrCreate diff --git a/project/build.properties b/project/build.properties index 52413ab..2743082 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.9.3 +sbt.version=1.9.6 diff --git a/udf/src/main/scala/scala3udf/Udf.scala b/udf/src/main/scala/scala3udf/Udf.scala index 75999b9..e4678d7 100644 --- a/udf/src/main/scala/scala3udf/Udf.scala +++ b/udf/src/main/scala/scala3udf/Udf.scala @@ -1,14 +1,11 @@ package scala3udf -import scala.reflect.ClassTag - import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.expressions.{Exporter, UserDefinedFunction} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import scala.compiletime.{summonInline, erasedValue} -import scala.deriving.Mirror import scala.quoted.* import scala3encoders.derivation.Deserializer diff --git a/udf/src/test/scala/scala3udf/UdfSpec.scala b/udf/src/test/scala/scala3udf/UdfSpec.scala index ab44076..2bb087e 100644 --- a/udf/src/test/scala/scala3udf/UdfSpec.scala +++ b/udf/src/test/scala/scala3udf/UdfSpec.scala @@ -9,7 +9,6 @@ import scala3udf.{ import scala3encoders.given import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.encoders.RowEncoder case class DataWithPos(name: String, x: Int, y: Int, z: Int) case class DataWithX(name: String, x: Int)
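---

Reviewer note, not part of the patch: below is a minimal smoke test for the Spark 3.5.0 upgrade. It is a sketch under stated assumptions: `Account` and `encoderSmokeTest` are hypothetical names, and it presumes that `scala3encoders.given` supplies both the Serializer and Deserializer instances needed to derive `Encoder[Account]` (this patch only shows the serializer side for `BigDecimal`/`BigInt`). It exercises the Decimal paths the Serializer.scala hunk touches: `createSerializerForScalaBigDecimal` preserves fractional digits, whereas routing a `BigDecimal` through `Helper.createSerializerForBigInteger` (DecimalType.BigIntDecimal, scale 0) would round them away.

```scala
import org.apache.spark.sql.SparkSession
import scala3encoders.given

// Hypothetical record covering the Decimal serializers touched in this patch.
case class Account(name: String, balance: BigDecimal, id: BigInt)

@main def encoderSmokeTest =
  val spark = SparkSession.builder().master("local").getOrCreate()
  try
    val data = Seq(
      Account("alice", BigDecimal("12.34"), BigInt(1)),
      Account("bob", BigDecimal("0.5"), BigInt(2))
    )
    // createDataset picks up the derived Encoder[Account] from scala3encoders.given
    val ds = spark.createDataset(data)
    ds.show()
    // Fractional digits should survive the round trip; Scala BigDecimal
    // equality ignores scale, so 12.34 == 12.340000000000000000 holds.
    assert(ds.collect().map(_.balance).toSet == data.map(_.balance).toSet)
  finally spark.stop()
```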