diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt index 12b81a0b..9c570738 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt @@ -36,6 +36,7 @@ import scala.collection.Iterable as ScalaIterable import scala.collection.Iterator as ScalaIterator import scala.collection.Map as ScalaMap import scala.collection.Seq as ScalaSeq +import scala.collection.immutable.Seq as ScalaImmutableSeq import scala.collection.Set as ScalaSet import scala.collection.concurrent.Map as ScalaConcurrentMap import scala.collection.mutable.Buffer as ScalaMutableBuffer @@ -124,6 +125,18 @@ fun Collection.asScalaIterable(): ScalaIterable = //$scala.collection.JavaConverters.collectionAsScalaIterable(this) //#endif +//#if scalaCompat >= 2.13 +/** @see scala.jdk.javaapi.CollectionConverters.asScala for more information. */ +//#else +//$/** @see scala.collection.JavaConverters.iterableAsScalaIterable for more information. */ +//#endif +fun Iterable.asScalaSeq(): ScalaImmutableSeq = + //#if scalaCompat >= 2.13 + scala.jdk.javaapi.CollectionConverters.asScala(this).toSeq() + //#else + //$scala.collection.JavaConverters.iterableAsScalaIterable(this).toSeq() + //#endif + //#if scalaCompat >= 2.13 /** @see scala.jdk.javaapi.CollectionConverters.asScala for more information. */ //#else @@ -363,803 +376,3 @@ fun ScalaConcurrentMap.asKotlinConcurrentMap(): ConcurrentMap //#else //$scala.collection.JavaConverters.mapAsJavaConcurrentMap(this) //#endif - - -/** - * Returns a new [Arity2] based on the arguments in the current [Pair]. - */ -@Deprecated("Use Scala tuples instead.", ReplaceWith("this.toTuple()", "scala.Tuple2")) -fun Pair.toArity(): Arity2 = Arity2(first, second) - -/** - * Returns a new [Pair] based on the arguments in the current [Arity2]. - */ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity2.toPair(): Pair = Pair(_1, _2) - -/** - * Returns a new [Arity3] based on the arguments in the current [Triple]. - */ -@Deprecated("Use Scala tuples instead.", ReplaceWith("this.toTuple()", "scala.Tuple3")) -fun Triple.toArity(): Arity3 = Arity3(first, second, third) - -/** - * Returns a new [Triple] based on the arguments in the current [Arity3]. - */ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity3.toTriple(): Triple = Triple(_1, _2, _3) - - -/** - * Returns a new Arity1 based on this Tuple1. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple1.toArity(): Arity1 = Arity1(this._1()) - -/** - * Returns a new Arity2 based on this Tuple2. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple2.toArity(): Arity2 = Arity2(this._1(), this._2()) - -/** - * Returns a new Arity3 based on this Tuple3. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple3.toArity(): Arity3 = Arity3(this._1(), this._2(), this._3()) - -/** - * Returns a new Arity4 based on this Tuple4. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple4.toArity(): Arity4 = - Arity4(this._1(), this._2(), this._3(), this._4()) - -/** - * Returns a new Arity5 based on this Tuple5. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple5.toArity(): Arity5 = - Arity5(this._1(), this._2(), this._3(), this._4(), this._5()) - -/** - * Returns a new Arity6 based on this Tuple6. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple6.toArity(): Arity6 = - Arity6(this._1(), this._2(), this._3(), this._4(), this._5(), this._6()) - -/** - * Returns a new Arity7 based on this Tuple7. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple7.toArity(): Arity7 = - Arity7(this._1(), this._2(), this._3(), this._4(), this._5(), this._6(), this._7()) - -/** - * Returns a new Arity8 based on this Tuple8. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple8.toArity(): Arity8 = - Arity8( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8() - ) - -/** - * Returns a new Arity9 based on this Tuple9. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple9.toArity(): Arity9 = - Arity9( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9() - ) - -/** - * Returns a new Arity10 based on this Tuple10. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple10.toArity(): Arity10 = - Arity10( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10() - ) - -/** - * Returns a new Arity11 based on this Tuple11. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple11.toArity(): Arity11 = - Arity11( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11() - ) - -/** - * Returns a new Arity12 based on this Tuple12. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple12.toArity(): Arity12 = - Arity12( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12() - ) - -/** - * Returns a new Arity13 based on this Tuple13. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple13.toArity(): Arity13 = - Arity13( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13() - ) - -/** - * Returns a new Arity14 based on this Tuple14. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple14.toArity(): Arity14 = - Arity14( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14() - ) - -/** - * Returns a new Arity15 based on this Tuple15. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) - -fun Tuple15.toArity(): Arity15 = - Arity15( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15() - ) - -/** - * Returns a new Arity16 based on this Tuple16. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple16.toArity(): Arity16 = - Arity16( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16() - ) - -/** - * Returns a new Arity17 based on this Tuple17. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple17.toArity(): Arity17 = - Arity17( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17() - ) - -/** - * Returns a new Arity18 based on this Tuple18. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple18.toArity(): Arity18 = - Arity18( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17(), - this._18() - ) - -/** - * Returns a new Arity19 based on this Tuple19. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple19.toArity(): Arity19 = - Arity19( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17(), - this._18(), - this._19() - ) - -/** - * Returns a new Arity20 based on this Tuple20. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple20.toArity(): Arity20 = - Arity20( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17(), - this._18(), - this._19(), - this._20() - ) - -/** - * Returns a new Arity21 based on this Tuple21. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple21.toArity(): Arity21 = - Arity21( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17(), - this._18(), - this._19(), - this._20(), - this._21() - ) - -/** - * Returns a new Arity22 based on this Tuple22. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Tuple22.toArity(): Arity22 = - Arity22( - this._1(), - this._2(), - this._3(), - this._4(), - this._5(), - this._6(), - this._7(), - this._8(), - this._9(), - this._10(), - this._11(), - this._12(), - this._13(), - this._14(), - this._15(), - this._16(), - this._17(), - this._18(), - this._19(), - this._20(), - this._21(), - this._22() - ) - -/** - * Returns a new Tuple1 based on this Arity1. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity1.toTuple(): Tuple1 = Tuple1(this._1) - -/** - * Returns a new Tuple2 based on this Arity2. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity2.toTuple(): Tuple2 = Tuple2(this._1, this._2) - -/** - * Returns a new Tuple3 based on this Arity3. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity3.toTuple(): Tuple3 = Tuple3(this._1, this._2, this._3) - -/** - * Returns a new Tuple4 based on this Arity4. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity4.toTuple(): Tuple4 = - Tuple4(this._1, this._2, this._3, this._4) - -/** - * Returns a new Tuple5 based on this Arity5. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity5.toTuple(): Tuple5 = - Tuple5(this._1, this._2, this._3, this._4, this._5) - -/** - * Returns a new Tuple6 based on this Arity6. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity6.toTuple(): Tuple6 = - Tuple6(this._1, this._2, this._3, this._4, this._5, this._6) - -/** - * Returns a new Tuple7 based on this Arity7. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity7.toTuple(): Tuple7 = - Tuple7(this._1, this._2, this._3, this._4, this._5, this._6, this._7) - -/** - * Returns a new Tuple8 based on this Arity8. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity8.toTuple(): Tuple8 = - Tuple8(this._1, this._2, this._3, this._4, this._5, this._6, this._7, this._8) - -/** - * Returns a new Tuple9 based on this Arity9. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity9.toTuple(): Tuple9 = - Tuple9( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9 - ) - -/** - * Returns a new Tuple10 based on this Arity10. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity10.toTuple(): Tuple10 = - Tuple10( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10 - ) - -/** - * Returns a new Tuple11 based on this Arity11. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity11.toTuple(): Tuple11 = - Tuple11( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11 - ) - -/** - * Returns a new Tuple12 based on this Arity12. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity12.toTuple(): Tuple12 = - Tuple12( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12 - ) - -/** - * Returns a new Tuple13 based on this Arity13. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity13.toTuple(): Tuple13 = - Tuple13( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13 - ) - -/** - * Returns a new Tuple14 based on this Arity14. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity14.toTuple(): Tuple14 = - Tuple14( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14 - ) - -/** - * Returns a new Tuple15 based on this Arity15. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity15.toTuple(): Tuple15 = - Tuple15( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15 - ) - -/** - * Returns a new Tuple16 based on this Arity16. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity16.toTuple(): Tuple16 = - Tuple16( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16 - ) - -/** - * Returns a new Tuple17 based on this Arity17. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity17.toTuple(): Tuple17 = - Tuple17( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17 - ) - -/** - * Returns a new Tuple18 based on this Arity18. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity18.toTuple(): Tuple18 = - Tuple18( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17, - this._18 - ) - -/** - * Returns a new Tuple19 based on this Arity19. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity19.toTuple(): Tuple19 = - Tuple19( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17, - this._18, - this._19 - ) - -/** - * Returns a new Tuple20 based on this Arity20. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity20.toTuple(): Tuple20 = - Tuple20( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17, - this._18, - this._19, - this._20 - ) - -/** - * Returns a new Tuple21 based on this Arity21. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity21.toTuple(): Tuple21 = - Tuple21( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17, - this._18, - this._19, - this._20, - this._21 - ) - -/** - * Returns a new Tuple22 based on this Arity22. - **/ -@Deprecated("Use Scala tuples instead.", ReplaceWith("")) -fun Arity22.toTuple(): Tuple22 = - Tuple22( - this._1, - this._2, - this._3, - this._4, - this._5, - this._6, - this._7, - this._8, - this._9, - this._10, - this._11, - this._12, - this._13, - this._14, - this._15, - this._16, - this._17, - this._18, - this._19, - this._20, - this._21, - this._22 - ) diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt index ecf62b19..54207bf5 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt @@ -29,99 +29,53 @@ package org.jetbrains.kotlinx.spark.api -import org.apache.spark.sql.* -import org.apache.spark.sql.Encoders.* -import org.apache.spark.sql.KotlinReflection.* +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.DefinedByConstructorParams +import org.apache.spark.sql.catalyst.SerializerBuildHelper +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.types.* +import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.Metadata +import org.apache.spark.sql.types.SQLUserDefinedType +import org.apache.spark.sql.types.UDTRegistration +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.unsafe.types.CalendarInterval -import scala.Product import scala.reflect.ClassTag -import java.beans.PropertyDescriptor -import java.math.BigDecimal -import java.math.BigInteger -import java.sql.Date -import java.sql.Timestamp -import java.time.* -import java.util.* -import java.util.concurrent.ConcurrentHashMap -import kotlin.Any -import kotlin.Array -import kotlin.Boolean -import kotlin.BooleanArray -import kotlin.Byte -import kotlin.ByteArray -import kotlin.Double -import kotlin.DoubleArray -import kotlin.ExperimentalStdlibApi -import kotlin.Float -import kotlin.FloatArray -import kotlin.IllegalArgumentException -import kotlin.Int -import kotlin.IntArray -import kotlin.Long -import kotlin.LongArray -import kotlin.OptIn -import kotlin.Short -import kotlin.ShortArray -import kotlin.String -import kotlin.Suppress -import kotlin.reflect.* -import kotlin.reflect.full.findAnnotation +import java.io.Serializable +import kotlin.reflect.KClass +import kotlin.reflect.KMutableProperty +import kotlin.reflect.KType +import kotlin.reflect.KTypeProjection +import kotlin.reflect.full.createType +import kotlin.reflect.full.declaredMemberProperties import kotlin.reflect.full.hasAnnotation import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.full.primaryConstructor +import kotlin.reflect.full.staticFunctions +import kotlin.reflect.full.withNullability +import kotlin.reflect.jvm.javaMethod import kotlin.reflect.jvm.jvmName -import kotlin.to - -@JvmField -val ENCODERS: Map, Encoder<*>> = mapOf( - Boolean::class to BOOLEAN(), - Byte::class to BYTE(), - Short::class to SHORT(), - Int::class to INT(), - Long::class to LONG(), - Float::class to FLOAT(), - Double::class to DOUBLE(), - String::class to STRING(), - BigDecimal::class to DECIMAL(), - Date::class to DATE(), - LocalDate::class to LOCALDATE(), - Timestamp::class to TIMESTAMP(), - Instant::class to INSTANT(), - ByteArray::class to BINARY(), - //#if sparkMinor >= 3.2 - Duration::class to DURATION(), - Period::class to PERIOD(), - //#endif -) - -private fun checkIfEncoderRequiresNewerVersion(kClass: KClass<*>) { - when (kClass) { - //#if sparkMinor < 3.2 - //$Duration::class, Period::class -> throw IllegalArgumentException("$kClass is supported in Spark 3.2+") - //#endif - } -} - -private val knownDataTypes: Map, DataType> = mapOf( - Byte::class to DataTypes.ByteType, - Short::class to DataTypes.ShortType, - Int::class to DataTypes.IntegerType, - Long::class to DataTypes.LongType, - Boolean::class to DataTypes.BooleanType, - Float::class to DataTypes.FloatType, - Double::class to DataTypes.DoubleType, - String::class to DataTypes.StringType, - LocalDate::class to DataTypes.DateType, - Date::class to DataTypes.DateType, - Timestamp::class to DataTypes.TimestampType, - Instant::class to DataTypes.TimestampType, - ByteArray::class to DataTypes.BinaryType, - Decimal::class to DecimalType.SYSTEM_DEFAULT(), - BigDecimal::class to DecimalType.SYSTEM_DEFAULT(), - BigInteger::class to DecimalType.SYSTEM_DEFAULT(), - CalendarInterval::class to DataTypes.CalendarIntervalType, +import kotlin.reflect.typeOf + +fun kotlinEncoderFor( + kClass: KClass, + arguments: List = emptyList(), + nullable: Boolean = false, + annotations: List = emptyList() +): Encoder = ExpressionEncoder.apply( + KotlinTypeInference.encoderFor( + kClass = kClass, + arguments = arguments, + nullable = nullable, + annotations = annotations, + ) ) /** @@ -133,253 +87,492 @@ private val knownDataTypes: Map, DataType> = mapOf( * @param T type, supported by Spark * @return generated encoder */ -@OptIn(ExperimentalStdlibApi::class) -inline fun encoder(): Encoder = generateEncoder(typeOf(), T::class) +inline fun kotlinEncoderFor(): Encoder = + ExpressionEncoder.apply( + KotlinTypeInference.encoderFor() + ) + +fun kotlinEncoderFor(kType: KType): Encoder = + ExpressionEncoder.apply( + KotlinTypeInference.encoderFor(kType) + ) -/** - * @see encoder - */ -@Suppress("UNCHECKED_CAST") -fun generateEncoder(type: KType, cls: KClass<*>): Encoder { - checkIfEncoderRequiresNewerVersion(cls) - return when { - isSupportedByKotlinClassEncoder(cls) -> kotlinClassEncoder(schema = memoizedSchema(type), kClass = cls) - else -> ENCODERS[cls] as? Encoder? ?: bean(cls.java) - } as Encoder -} -private fun isSupportedByKotlinClassEncoder(cls: KClass<*>): Boolean = - when { - cls == ByteArray::class -> false // uses binary encoder - cls.isData -> true - cls.isSubclassOf(Map::class) -> true - cls.isSubclassOf(Iterable::class) -> true - cls.isSubclassOf(Product::class) -> true - cls.java.isArray -> true - cls.hasAnnotation() -> true - UDTRegistration.exists(cls.jvmName) -> true - else -> false +@Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor()")) +inline fun encoder(): Encoder = kotlinEncoderFor(typeOf()) + +@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor().schema()")) +inline fun schema(): DataType = kotlinEncoderFor().schema() + +@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor(kType).schema()")) +fun schema(kType: KType): DataType = kotlinEncoderFor(kType).schema() + +object KotlinTypeInference { + + /** + * @param kClass the class for which to infer the encoder. + * @param arguments the generic type arguments for the class. + * @param nullable whether the class is nullable. + * @param annotations the annotations for the class. + * @return an [AgnosticEncoder] for the given class arguments. + */ + fun encoderFor( + kClass: KClass, + arguments: List = emptyList(), + nullable: Boolean = false, + annotations: List = emptyList() + ): AgnosticEncoder = encoderFor( + kType = kClass.createType( + arguments = arguments, + nullable = nullable, + annotations = annotations, + ) + ) + + /** + * @return an [AgnosticEncoder] for the given type [T]. + */ + @JvmName("inlineEncoderFor") + inline fun encoderFor(): AgnosticEncoder = + encoderFor(kType = typeOf()) + + /** + * Main entry function for the inference of encoders. + * + * @return an [AgnosticEncoder] for the given [kType]. + */ + fun encoderFor(kType: KType): AgnosticEncoder = + encoderFor( + currentType = kType, + seenTypeSet = emptySet(), + typeVariables = emptyMap(), + ) as AgnosticEncoder + + + private inline fun KType.isSubtypeOf(): Boolean = isSubtypeOf(typeOf()) + + private val KType.simpleName + get() = toString().removeSuffix("?").removeSuffix("!") + + private fun KType.isDefinedByScalaConstructorParams(): Boolean = when { + isSubtypeOf?>() -> arguments.first().type!!.isDefinedByScalaConstructorParams() + else -> isSubtypeOf() || isSubtypeOf() } + private fun KType.getScalaConstructorParameters( + genericTypeMap: Map, + kClass: KClass<*> = classifier as KClass<*>, + ): List> { + val constructor = + kClass.primaryConstructor + ?: kClass.constructors.firstOrNull() + ?: kClass.staticFunctions.firstOrNull { + it.name == "apply" && it.returnType.classifier == kClass + } + ?: error("couldn't find constructor for $this") + + val kParameters = constructor.parameters + val params = kParameters.map { param -> + val paramType = if (param.type.isSubtypeOf()) { + // Replace value class with underlying type + param.type.getScalaConstructorParameters(genericTypeMap).first().second + } else { + // check if the type was a filled-in generic type, otherwise just use the given type + genericTypeMap[param.type.simpleName] ?: param.type + } -private fun kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder { - val serializer = - if (schema is DataTypeWithClass) serializerFor(kClass.java, schema) - else serializerForType(getType(kClass.java)) + param.name!! to paramType + } - val deserializer = - if (schema is DataTypeWithClass) deserializerFor(kClass.java, schema) - else deserializerForType(getType(kClass.java)) + return params + } - return ExpressionEncoder(serializer, deserializer, ClassTag.apply(kClass.java)) -} + /** + * Can merge two maps transitively. + * This means that given + * ``` + * a: { A -> B, D -> E } + * b: { B -> C, G -> F } + * ``` + * it will return + * ``` + * { A -> C, D -> E, G -> F } + * ``` + * @param valueToKey a function that returns (an optional) key for a given value + */ + private fun transitiveMerge(a: Map, b: Map, valueToKey: (V) -> K?): Map = + a + b.mapValues { a.getOrDefault(valueToKey(it.value), it.value) } + + /** + * + */ + private fun encoderFor( + currentType: KType, + seenTypeSet: Set, + + // how the generic types of the data class (like T, S) are filled in for this instance of the class + typeVariables: Map, + ): AgnosticEncoder<*> { + val kClass = + currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType") + val jClass = kClass.java + + // given t == typeOf>>(), these are [Int, Pair] + val tArguments = currentType.arguments + + // the type arguments of the class, like T, S + val expectedTypeParameters = kClass.typeParameters.map { it } + + @Suppress("NAME_SHADOWING") + val typeVariables = transitiveMerge( + a = typeVariables, + b = (expectedTypeParameters zip tArguments).toMap() + .mapValues { (expectedType, givenType) -> + if (givenType.type != null) return@mapValues givenType.type!! // fill in the type as is + + // when givenType is *, use the upperbound + expectedType.upperBounds.first() + }.mapKeys { it.key.name } + ) { it.simpleName } + + return when { + // primitives java / kotlin + currentType == typeOf() -> AgnosticEncoders.`PrimitiveBooleanEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveByteEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveShortEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveIntEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveLongEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveFloatEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveDoubleEncoder$`.`MODULE$` + + // primitives scala + currentType == typeOf() -> AgnosticEncoders.`PrimitiveBooleanEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveByteEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveShortEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveIntEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveLongEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveFloatEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`PrimitiveDoubleEncoder$`.`MODULE$` + + // boxed primitives java / kotlin + currentType == typeOf() -> AgnosticEncoders.`BoxedBooleanEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedByteEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedShortEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedIntEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedLongEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedFloatEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedDoubleEncoder$`.`MODULE$` + + // boxed primitives scala + currentType == typeOf() -> AgnosticEncoders.`BoxedBooleanEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedByteEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedShortEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedIntEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedLongEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedFloatEncoder$`.`MODULE$` + currentType == typeOf() -> AgnosticEncoders.`BoxedDoubleEncoder$`.`MODULE$` + + // leaf encoders + currentType.isSubtypeOf() -> AgnosticEncoders.`StringEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> AgnosticEncoders.DEFAULT_SPARK_DECIMAL_ENCODER() + currentType.isSubtypeOf() -> AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER() + currentType.isSubtypeOf() -> AgnosticEncoders.`ScalaBigIntEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> AgnosticEncoders.`BinaryEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER() + currentType.isSubtypeOf() -> AgnosticEncoders.`JavaBigIntEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> AgnosticEncoders.`CalendarIntervalEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER() + currentType.isSubtypeOf() -> TODO("User java.time.LocalDate for now.") + currentType.isSubtypeOf() -> AgnosticEncoders.STRICT_DATE_ENCODER() + currentType.isSubtypeOf() -> AgnosticEncoders.STRICT_INSTANT_ENCODER() + currentType.isSubtypeOf() -> TODO("Use java.time.Instant for now.") + currentType.isSubtypeOf() -> TODO("Use java.time.Instant for now.") + currentType.isSubtypeOf() -> AgnosticEncoders.STRICT_TIMESTAMP_ENCODER() + currentType.isSubtypeOf() -> AgnosticEncoders.`LocalDateTimeEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> TODO("Use java.time.LocalDateTime for now.") + currentType.isSubtypeOf() -> AgnosticEncoders.`DayTimeIntervalEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> TODO("Use java.time.Duration for now.") + currentType.isSubtypeOf() -> AgnosticEncoders.`YearMonthIntervalEncoder$`.`MODULE$` + currentType.isSubtypeOf() -> TODO("Use java.time.Period for now.") + currentType.isSubtypeOf() -> TODO("Use java.time.Period for now.") + currentType.isSubtypeOf() -> AgnosticEncoders.`UnboundRowEncoder$`.`MODULE$` + + // enums + kClass.isSubclassOf(Enum::class) -> AgnosticEncoders.JavaEnumEncoder(ClassTag.apply(jClass)) + + // TODO test + kClass.isSubclassOf(scala.Enumeration.Value::class) -> + AgnosticEncoders.ScalaEnumEncoder(jClass.superclass, ClassTag.apply(jClass)) + + // udts + currentType.hasAnnotation() -> { + val annotation = jClass.getAnnotation(SQLUserDefinedType::class.java)!! + val udtClass = annotation.udt + val udt = udtClass.primaryConstructor!!.call() + AgnosticEncoders.UDTEncoder(udt, udtClass.java) + } -/** - * Not meant to be used by the user explicitly. - * - * This function generates the DataType schema for supported classes, including Kotlin data classes, [Map], - * [Iterable], [Product], [Array], and combinations of those. - * - * It's mainly used by [generateEncoder]/[encoder]. - */ -@OptIn(ExperimentalStdlibApi::class) -fun schema(type: KType, map: Map = mapOf()): DataType { - val primitiveSchema = knownDataTypes[type.classifier] - if (primitiveSchema != null) - return KSimpleTypeWrapper( - /* dt = */ primitiveSchema, - /* cls = */ (type.classifier!! as KClass<*>).java, - /* nullable = */ type.isMarkedNullable - ) + UDTRegistration.exists(kClass.jvmName) -> { + val udt = UDTRegistration.getUDTFor(kClass.jvmName)!! + .get()!! + .getConstructor() + .newInstance() as UserDefinedType<*> - val klass = type.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $type") - val args = type.arguments + AgnosticEncoders.UDTEncoder(udt, udt.javaClass) + } - val types = transitiveMerge( - map, - klass.typeParameters.zip(args).associate { - it.first.name to it.second.type!! - }, - ) + currentType.isSubtypeOf>() -> { + val elementEncoder = encoderFor( + currentType = tArguments.first().type!!, + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.OptionEncoder(elementEncoder) + } - return when { - klass.isSubclassOf(Enum::class) -> - KSimpleTypeWrapper( - /* dt = */ DataTypes.StringType, - /* cls = */ klass.java, - /* nullable = */ type.isMarkedNullable - ) - - klass.isSubclassOf(Iterable::class) || klass.java.isArray -> { - val listParam = if (klass.java.isArray) { - when (klass) { - IntArray::class -> typeOf() - LongArray::class -> typeOf() - FloatArray::class -> typeOf() - DoubleArray::class -> typeOf() - BooleanArray::class -> typeOf() - ShortArray::class -> typeOf() - /* ByteArray handled by BinaryType */ - else -> types.getValue(klass.typeParameters[0].name) - } - } else types.getValue(klass.typeParameters[0].name) - - val dataType = DataTypes.createArrayType( - /* elementType = */ schema(listParam, types), - /* containsNull = */ listParam.isMarkedNullable - ) - - KComplexTypeWrapper( - /* dt = */ dataType, - /* cls = */ klass.java, - /* nullable = */ type.isMarkedNullable - ) - } + // primitive arrays + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, false) + } - klass == Map::class -> { - val mapKeyParam = types.getValue(klass.typeParameters[0].name) - val mapValueParam = types.getValue(klass.typeParameters[1].name) - - val dataType = DataTypes.createMapType( - /* keyType = */ schema(mapKeyParam, types), - /* valueType = */ schema(mapValueParam, types), - /* valueContainsNull = */ true - ) - - KComplexTypeWrapper( - /* dt = */ dataType, - /* cls = */ klass.java, - /* nullable = */ type.isMarkedNullable - ) - } + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, false) + } - klass.isData -> { - - val structType = StructType( - klass - .primaryConstructor!! - .parameters - .filter { it.findAnnotation() == null } - .map { - val projectedType = types[it.type.toString()] ?: it.type - - val readMethodName = when { - it.name!!.startsWith("is") -> it.name!! - else -> "get${it.name!!.replaceFirstChar { it.uppercase() }}" - } - - val propertyDescriptor = PropertyDescriptor( - /* propertyName = */ it.name, - /* beanClass = */ klass.java, - /* readMethodName = */ readMethodName, - /* writeMethodName = */ null - ) - - KStructField( - /* getterName = */ propertyDescriptor.readMethod.name, - /* delegate = */ StructField( - /* name = */ it.name, - /* dataType = */ schema(projectedType, types), - /* nullable = */ projectedType.isMarkedNullable, - /* metadata = */ Metadata.empty() - ) - ) - } - .toTypedArray() - ) - KDataTypeWrapper(structType, klass.java, true) - } - klass.isSubclassOf(Product::class) -> { + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, false) + } - // create map from T1, T2 to Int, String etc. - val typeMap = klass.constructors.first().typeParameters.map { it.name } - .zip( - type.arguments.map { it.type } + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, ) - .toMap() + AgnosticEncoders.ArrayEncoder(elementEncoder, false) + } - // collect params by name and actual type - val params = klass.constructors.first().parameters.map { - val typeName = it.type.toString().replace("!", "") - it.name to (typeMap[typeName] ?: it.type) + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, false) } - val structType = DataTypes.createStructType( - params.map { (fieldName, fieldType) -> - val dataType = schema(fieldType, types) - - KStructField( - /* getterName = */ fieldName, - /* delegate = */ StructField( - /* name = */ fieldName, - /* dataType = */ dataType, - /* nullable = */ fieldType.isMarkedNullable, - /* metadata = */Metadata.empty() - ) - ) - }.toTypedArray() - ) - - KComplexTypeWrapper( - /* dt = */ structType, - /* cls = */ klass.java, - /* nullable = */ true - ) - } + currentType.isSubtypeOf() -> { + val elementEncoder = encoderFor( + currentType = typeOf(), + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, false) + } - UDTRegistration.exists(klass.jvmName) -> { - @Suppress("UNCHECKED_CAST") - val dataType = UDTRegistration.getUDTFor(klass.jvmName) - .getOrNull()!! - .let { it as Class> } - .getConstructor() - .newInstance() - - KSimpleTypeWrapper( - /* dt = */ dataType, - /* cls = */ klass.java, - /* nullable = */ type.isMarkedNullable - ) - } + // boxed arrays + jClass.isArray -> { + val type = currentType.arguments.first().type!! + val elementEncoder = encoderFor( + currentType = type.withNullability(true), // so we get a boxed array + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.ArrayEncoder(elementEncoder, true) + } - klass.hasAnnotation() -> { - val dataType = klass.findAnnotation()!! - .udt - .java - .getConstructor() - .newInstance() - - KSimpleTypeWrapper( - /* dt = */ dataType, - /* cls = */ klass.java, - /* nullable = */ type.isMarkedNullable - ) - } + currentType.isSubtypeOf?>() -> { + val subType = tArguments.first().type!! + val elementEncoder = encoderFor( + currentType = subType, + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.IterableEncoder, _>( + /* clsTag = */ ClassTag.apply(jClass), + /* element = */ elementEncoder, + /* containsNull = */ subType.isMarkedNullable, + /* lenientSerialization = */ false, + ) + } - else -> throw IllegalArgumentException("$type is unsupported") - } -} + currentType.isSubtypeOf?>() -> { + val subType = tArguments.first().type!! + val elementEncoder = encoderFor( + currentType = subType, + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.IterableEncoder, _>( + /* clsTag = */ ClassTag.apply(jClass), + /* element = */ elementEncoder, + /* containsNull = */ subType.isMarkedNullable, + /* lenientSerialization = */ false, + ) + } -/** - * Memoized version of [schema]. This ensures the [DataType] of given `type` only - * has to be calculated once. - */ -private val memoizedSchema: (type: KType) -> DataType = memoize { - schema(it) -} + currentType.isSubtypeOf?>() -> { + val subType = tArguments.first().type!! + val elementEncoder = encoderFor( + currentType = subType, + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.IterableEncoder, _>( + /* clsTag = */ ClassTag.apply(jClass), + /* element = */ elementEncoder, + /* containsNull = */ subType.isMarkedNullable, + /* lenientSerialization = */ false, + ) + } + + currentType.isSubtypeOf?>() -> { + val subType = tArguments.first().type!! + val elementEncoder = encoderFor( + currentType = subType, + seenTypeSet = seenTypeSet, + typeVariables = typeVariables, + ) + AgnosticEncoders.IterableEncoder, _>( + /* clsTag = */ ClassTag.apply(jClass), + /* element = */ elementEncoder, + /* containsNull = */ subType.isMarkedNullable, + /* lenientSerialization = */ false, + ) + } -private fun transitiveMerge(a: Map, b: Map): Map = - a + b.mapValues { a.getOrDefault(it.value.toString(), it.value) } + currentType.isSubtypeOf?>() -> TODO() + currentType.isSubtypeOf?>() -> TODO() -/** Wrapper around function with 1 argument to avoid recalculation when a certain argument is queried again. */ -private class Memoize1(private val function: (T) -> R) : (T) -> R { - private val values = ConcurrentHashMap() - override fun invoke(x: T): R = values.getOrPut(x) { function(x) } -} + kClass.isData -> { + if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType") + val constructor = kClass.primaryConstructor!! + val kParameters = constructor.parameters + // todo filter for transient? + + val props = kParameters.map { + kClass.declaredMemberProperties.find { prop -> prop.name == it.name }!! + } + + val params = (kParameters zip props).map { (param, prop) -> + // check if the type was a filled-in generic type, otherwise just use the given type + val paramType = typeVariables[param.type.simpleName] ?: param.type + val encoder = encoderFor( + currentType = paramType, + seenTypeSet = seenTypeSet + currentType, + typeVariables = typeVariables, + ) + val paramName = param.name!! + val readMethodName = prop.getter.javaMethod!!.name + val writeMethodName = (prop as? KMutableProperty<*>)?.setter?.javaMethod?.name + + DirtyProductEncoderField( + name = paramName, + readMethodName = readMethodName, + writeMethodName = writeMethodName, + encoder = encoder, + nullable = paramType.isMarkedNullable, + ) + } + ProductEncoder( + /* clsTag = */ ClassTag.apply(jClass), + /* fields = */ params.asScalaSeq(), + /* outerPointerGetter = */ OuterScopes.getOuterScope(jClass).toOption(), + ) + } + + currentType.isDefinedByScalaConstructorParams() -> { + if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType") + val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass) -/** Wrapper around function to avoid recalculation when a certain argument is queried again. */ -private fun ((T) -> R).memoized(): (T) -> R = Memoize1(this) + val params: List = constructorParams.map { (paramName, paramType) -> + val encoder = encoderFor( + currentType = paramType, + seenTypeSet = seenTypeSet + currentType, + typeVariables = typeVariables, + ) + AgnosticEncoders.EncoderField( + /* name = */ paramName, + /* enc = */ encoder, + /* nullable = */ paramType.isMarkedNullable, + /* metadata = */ Metadata.empty(), + /* readMethod = */ paramName.toOption(), + /* writeMethod = */ null.toOption(), + ) + } + ProductEncoder( + /* clsTag = */ ClassTag.apply(jClass), + /* fields = */ params.asScalaSeq(), + /* outerPointerGetter = */ OuterScopes.getOuterScope(jClass).toOption(), + ) + } + + // java bean class +// currentType.classifier is KClass<*> -> { +// TODO() +// +// JavaBeanEncoder() +// } -/** Wrapper around function to avoid recalculation when a certain argument is queried again. */ -private fun memoize(function: (T) -> R): (T) -> R = Memoize1(function) + else -> throw IllegalArgumentException("No encoder found for type $currentType") + } + } +} + +internal open class DirtyProductEncoderField( + private val name: String, // the name used for the column + private val readMethodName: String, // the name of the method used to read the value + private val writeMethodName: String?, + encoder: AgnosticEncoder<*>, + nullable: Boolean, + metadata: Metadata = Metadata.empty(), +) : AgnosticEncoders.EncoderField( + /* name = */ readMethodName, + /* enc = */ encoder, + /* nullable = */ nullable, + /* metadata = */ metadata, + /* readMethod = */ readMethodName.toOption(), + /* writeMethod = */ writeMethodName.toOption(), +), Serializable { + + private var i = 0 + + /** + * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder] + * creates an [Invoke] using [name] first and then calls [name] again to retrieve + * the name of the column. This way, we can alternate between the two names. + */ + override fun name(): String = + if (i++ % 2 == 0) readMethodName + else name + + override fun canEqual(that: Any?): Boolean = that is AgnosticEncoders.EncoderField + + override fun productElement(n: Int): Any = + when (n) { + 0 -> readMethodName // so it doesn't affect name() + 1 -> enc() + 2 -> nullable() + 3 -> metadata() + 4 -> readMethod() + 5 -> writeMethod() + else -> throw IndexOutOfBoundsException() + } + override fun productArity(): Int = 6 +} \ No newline at end of file