Skip to content

Commit

Permalink
enabled core as scala-helpers with VarargUnwrapper. Removed name hack…
Browse files Browse the repository at this point in the history
… in favor of upcoming IR compiler plugin. Removed spark dependency in scala-helpers
  • Loading branch information
Jolanrensen committed Mar 20, 2024
1 parent 2c875ff commit d60e4dc
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,20 @@ package org.jetbrains.kotlinx.spark.api
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.SerializerBuildHelper
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.EncoderField
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.SQLUserDefinedType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.unsafe.types.CalendarInterval
import scala.reflect.ClassTag
import java.io.Serializable
import kotlin.reflect.KClass
import kotlin.reflect.KMutableProperty
import kotlin.reflect.KType
Expand Down Expand Up @@ -113,11 +111,13 @@ private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
@Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
inline fun <reified T> encoder(): Encoder<T> = kotlinEncoderFor(typeOf<T>())

@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor<T>().schema()"))
inline fun <reified T> schema(): DataType = kotlinEncoderFor<T>().schema()
internal fun StructType.unwrap(): DataType =
if (fields().singleOrNull()?.name() == "value") fields().single().dataType()
else this

@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor<Any?>(kType).schema()"))
fun schema(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema()
inline fun <reified T> schemaFor(): DataType = schemaFor(typeOf<T>())

fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().unwrap()

object KotlinTypeInference {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class UDFWrapper0(private val udfName: String) {
@OptIn(ExperimentalStdlibApi::class)
@Deprecated("Use new UDF notation", ReplaceWith("this.register(name, func)"), DeprecationLevel.HIDDEN)
inline fun <reified R> UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
register(name, UDF0(func), kotlinEncoderFor<R>().schema())
register(name, UDF0(func), schemaFor<R>())
return UDFWrapper0(name)
}

Expand All @@ -78,7 +78,7 @@ class UDFWrapper1(private val udfName: String) {
@Deprecated("Use new UDF notation", ReplaceWith("this.register(name, func)"), DeprecationLevel.HIDDEN)
inline fun <reified T0, reified R> UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
T0::class.checkForValidType("T0")
register(name, UDF1(func), kotlinEncoderFor<R>().schema())
register(name, UDF1(func), schemaFor<R>())
return UDFWrapper1(name)
}

Expand Down Expand Up @@ -107,7 +107,7 @@ inline fun <reified T0, reified T1, reified R> UDFRegistration.register(
): UDFWrapper2 {
T0::class.checkForValidType("T0")
T1::class.checkForValidType("T1")
register(name, UDF2(func), kotlinEncoderFor<R>().schema())
register(name, UDF2(func), schemaFor<R>())
return UDFWrapper2(name)
}

Expand Down Expand Up @@ -137,7 +137,7 @@ inline fun <reified T0, reified T1, reified T2, reified R> UDFRegistration.regis
T0::class.checkForValidType("T0")
T1::class.checkForValidType("T1")
T2::class.checkForValidType("T2")
register(name, UDF3(func), kotlinEncoderFor<R>().schema())
register(name, UDF3(func), schemaFor<R>())
return UDFWrapper3(name)
}

Expand Down Expand Up @@ -168,7 +168,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified R> UDFRegist
T1::class.checkForValidType("T1")
T2::class.checkForValidType("T2")
T3::class.checkForValidType("T3")
register(name, UDF4(func), kotlinEncoderFor<R>().schema())
register(name, UDF4(func), schemaFor<R>())
return UDFWrapper4(name)
}

Expand Down Expand Up @@ -200,7 +200,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T2::class.checkForValidType("T2")
T3::class.checkForValidType("T3")
T4::class.checkForValidType("T4")
register(name, UDF5(func), kotlinEncoderFor<R>().schema())
register(name, UDF5(func), schemaFor<R>())
return UDFWrapper5(name)
}

Expand Down Expand Up @@ -240,7 +240,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T3::class.checkForValidType("T3")
T4::class.checkForValidType("T4")
T5::class.checkForValidType("T5")
register(name, UDF6(func), kotlinEncoderFor<R>().schema())
register(name, UDF6(func), schemaFor<R>())
return UDFWrapper6(name)
}

Expand Down Expand Up @@ -282,7 +282,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T4::class.checkForValidType("T4")
T5::class.checkForValidType("T5")
T6::class.checkForValidType("T6")
register(name, UDF7(func), kotlinEncoderFor<R>().schema())
register(name, UDF7(func), schemaFor<R>())
return UDFWrapper7(name)
}

Expand Down Expand Up @@ -326,7 +326,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T5::class.checkForValidType("T5")
T6::class.checkForValidType("T6")
T7::class.checkForValidType("T7")
register(name, UDF8(func), kotlinEncoderFor<R>().schema())
register(name, UDF8(func), schemaFor<R>())
return UDFWrapper8(name)
}

Expand Down Expand Up @@ -372,7 +372,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T6::class.checkForValidType("T6")
T7::class.checkForValidType("T7")
T8::class.checkForValidType("T8")
register(name, UDF9(func), kotlinEncoderFor<R>().schema())
register(name, UDF9(func), schemaFor<R>())
return UDFWrapper9(name)
}

Expand Down Expand Up @@ -432,7 +432,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T7::class.checkForValidType("T7")
T8::class.checkForValidType("T8")
T9::class.checkForValidType("T9")
register(name, UDF10(func), kotlinEncoderFor<R>().schema())
register(name, UDF10(func), schemaFor<R>())
return UDFWrapper10(name)
}

Expand Down Expand Up @@ -495,7 +495,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T8::class.checkForValidType("T8")
T9::class.checkForValidType("T9")
T10::class.checkForValidType("T10")
register(name, UDF11(func), kotlinEncoderFor<R>().schema())
register(name, UDF11(func), schemaFor<R>())
return UDFWrapper11(name)
}

Expand Down Expand Up @@ -561,7 +561,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T9::class.checkForValidType("T9")
T10::class.checkForValidType("T10")
T11::class.checkForValidType("T11")
register(name, UDF12(func), kotlinEncoderFor<R>().schema())
register(name, UDF12(func), schemaFor<R>())
return UDFWrapper12(name)
}

Expand Down Expand Up @@ -630,7 +630,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T10::class.checkForValidType("T10")
T11::class.checkForValidType("T11")
T12::class.checkForValidType("T12")
register(name, UDF13(func), kotlinEncoderFor<R>().schema())
register(name, UDF13(func), schemaFor<R>())
return UDFWrapper13(name)
}

Expand Down Expand Up @@ -702,7 +702,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T11::class.checkForValidType("T11")
T12::class.checkForValidType("T12")
T13::class.checkForValidType("T13")
register(name, UDF14(func), kotlinEncoderFor<R>().schema())
register(name, UDF14(func), schemaFor<R>())
return UDFWrapper14(name)
}

Expand Down Expand Up @@ -777,7 +777,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T12::class.checkForValidType("T12")
T13::class.checkForValidType("T13")
T14::class.checkForValidType("T14")
register(name, UDF15(func), kotlinEncoderFor<R>().schema())
register(name, UDF15(func), schemaFor<R>())
return UDFWrapper15(name)
}

Expand Down Expand Up @@ -855,7 +855,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T13::class.checkForValidType("T13")
T14::class.checkForValidType("T14")
T15::class.checkForValidType("T15")
register(name, UDF16(func), kotlinEncoderFor<R>().schema())
register(name, UDF16(func), schemaFor<R>())
return UDFWrapper16(name)
}

Expand Down Expand Up @@ -936,7 +936,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T14::class.checkForValidType("T14")
T15::class.checkForValidType("T15")
T16::class.checkForValidType("T16")
register(name, UDF17(func), kotlinEncoderFor<R>().schema())
register(name, UDF17(func), schemaFor<R>())
return UDFWrapper17(name)
}

Expand Down Expand Up @@ -1020,7 +1020,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T15::class.checkForValidType("T15")
T16::class.checkForValidType("T16")
T17::class.checkForValidType("T17")
register(name, UDF18(func), kotlinEncoderFor<R>().schema())
register(name, UDF18(func), schemaFor<R>())
return UDFWrapper18(name)
}

Expand Down Expand Up @@ -1107,7 +1107,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T16::class.checkForValidType("T16")
T17::class.checkForValidType("T17")
T18::class.checkForValidType("T18")
register(name, UDF19(func), kotlinEncoderFor<R>().schema())
register(name, UDF19(func), schemaFor<R>())
return UDFWrapper19(name)
}

Expand Down Expand Up @@ -1197,7 +1197,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T17::class.checkForValidType("T17")
T18::class.checkForValidType("T18")
T19::class.checkForValidType("T19")
register(name, UDF20(func), kotlinEncoderFor<R>().schema())
register(name, UDF20(func), schemaFor<R>())
return UDFWrapper20(name)
}

Expand Down Expand Up @@ -1290,7 +1290,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T18::class.checkForValidType("T18")
T19::class.checkForValidType("T19")
T20::class.checkForValidType("T20")
register(name, UDF21(func), kotlinEncoderFor<R>().schema())
register(name, UDF21(func), schemaFor<R>())
return UDFWrapper21(name)
}

Expand Down Expand Up @@ -1386,6 +1386,6 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
T19::class.checkForValidType("T19")
T20::class.checkForValidType("T20")
T21::class.checkForValidType("T21")
register(name, UDF22(func), kotlinEncoderFor<R>().schema())
register(name, UDF22(func), schemaFor<R>())
return UDFWrapper22(name)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.spark.api

import org.apache.spark.sql.*
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.StructType
import scala.collection.Seq
import java.io.Serializable
import kotlin.reflect.KClass
Expand All @@ -31,6 +32,7 @@ import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.primaryConstructor
import org.apache.spark.sql.expressions.UserDefinedFunction as SparkUserDefinedFunction


/**
* Checks if [this] is of a valid type for a UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ByteArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ByteArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -334,7 +334,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ShortArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ShortArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -533,7 +533,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> IntArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> IntArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -732,7 +732,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> LongArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> LongArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -931,7 +931,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> FloatArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> FloatArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -1130,7 +1130,7 @@ inline fun <reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> DoubleArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> DoubleArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -1325,11 +1325,9 @@ inline fun <reified R> udf(
nondeterministic: Boolean = false,
varargFunc: UDF1<BooleanArray, R>,
): UserDefinedFunctionVararg<Boolean, R> {


return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> BooleanArray(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> BooleanArray(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down Expand Up @@ -1528,7 +1526,7 @@ inline fun <reified T, reified R> udf(

return withAllowUntypedScalaUDF {
UserDefinedFunctionVararg(
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> Array<T>(i, init::call) }, kotlinEncoderFor<R>().schema())
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> Array<T>(i, init::apply) }, schemaFor<R>())
.let { if (nondeterministic) it.asNondeterministic() else it }
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
encoder = kotlinEncoderFor<R>(),
Expand Down
Loading

0 comments on commit d60e4dc

Please sign in to comment.