Skip to content

Commit

Permalink
only remove ir functions registered in that execute request
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Nov 4, 2024
1 parent 220c123 commit cbe8467
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
11 changes: 7 additions & 4 deletions hail/src/main/scala/is/hail/backend/BackendRpc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package is.hail.backend

import is.hail.expr.ir.IRParser
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey
import is.hail.io.BufferSpec
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.services.retryTransientErrors
import is.hail.types.virtual.{Kind, TFloat64, VType}
import is.hail.types.virtual.Kinds._
import is.hail.utils.{using, ExecutionTimer}
import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer}
import is.hail.utils.ExecutionTimer.Timings
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -177,9 +178,10 @@ trait BackendRpc {
)(
body: => A
): A = {
val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length)
try {
serializedFunctions.foreach { func =>
IRFunctionRegistry.registerIR(
for (func <- serializedFunctions) {
fns += IRFunctionRegistry.registerIR(
ctx,
func.name,
func.type_parameters,
Expand All @@ -192,7 +194,8 @@ trait BackendRpc {

body
} finally
IRFunctionRegistry.clearUserFunctions()
for (i <- 0 until fns.length)
IRFunctionRegistry.unregisterIr(fns(i))
}
}

Expand Down
58 changes: 43 additions & 15 deletions hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import scala.reflect._
import org.apache.spark.sql.Row

object IRFunctionRegistry {
private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] =
type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type]))

private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] =
mutable.HashSet.empty

def clearUserFunctions(): Unit = {
Expand Down Expand Up @@ -69,25 +71,41 @@ object IRFunctionRegistry {
typeParamStrs: Array[String],
argNameStrs: Array[String],
argTypeStrs: Array[String],
returnType: String,
returnTypeStr: String,
bodyStr: String,
): Unit = {
): UserDefinedFnKey = {
requireJavaIdentifier(name)
val argNames = argNameStrs.map(Name)
val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq
val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq
val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*)
val body = IRParser.parse_value_ir(ctx, bodyStr, refMap)
val argNames = argNameStrs.map(Name)

val body =
IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*))
val returnType = IRParser.parseType(returnTypeStr)
assert(body.typ == returnType)

userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes)))
val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes))
userAddedFunctions += key
addIR(
name,
typeParameters,
valueParameterTypes,
IRParser.parseType(returnType),
returnType,
false,
(_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)),
)
key
}

def unregisterIr(key: UserDefinedFnKey): Unit = {
val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key
if (userAddedFunctions.remove(key))
removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes)
else {
throw new NoSuchElementException(
s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}"
)
}
}

def removeIRFunction(
Expand All @@ -112,7 +130,9 @@ object IRFunctionRegistry {
case Seq() => None
case Seq(f) => Some(f)
case _ =>
fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
fatal(
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
}

def lookupFunctionOrFail(
Expand All @@ -124,28 +144,34 @@ object IRFunctionRegistry {
jvmRegistry.lift(name) match {
case None =>
fatal(
s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType"
s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
case Some(functions) =>
functions.filter(t =>
t.unify(typeParameters, valueParameterTypes, returnType)
).toSeq match {
case Seq() =>
val prettyFunctionSignature =
s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"
val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n")
fatal(
s"No function found with the signature $prettyFunctionSignature.\n" +
s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" +
s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures"
)
case Seq(f) => f
case _ => fatal(
s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})."
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})."
)
}
}
}

private[this] def prettyFunctionSignature(
name: String,
returnType: Type,
typeParameterTypes: Seq[Type],
valueParameterTypes: Seq[Type],
): String =
s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"

def lookupIR(
name: String,
returnType: Type,
Expand All @@ -165,7 +191,9 @@ object IRFunctionRegistry {
case Seq() => None
case Seq(kv) => Some(kv)
case _ =>
fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
fatal(
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
}
}

Expand Down

0 comments on commit cbe8467

Please sign in to comment.