Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Expose references via ExecuteContext #14686

Open
wants to merge 5 commits into
base: ehigham/cloud-credentials
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:

def persist_expression(self, expr):
t = expr.dtype
return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t)
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

# Report whether `name` is a member of this backend's
# `_registered_ir_function_names` collection.
def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def set_flags(self, **flags: Mapping[str, str]):
available = self._jbackend.availableFlags()
available = self._jbackend.pyAvailableFlags()
invalid = []
for flag, value in flags.items():
if flag in available:
self._jbackend.setFlag(flag, value)
self._jbackend.pySetFlag(flag, value)
else:
invalid.append(flag)
if len(invalid) != 0:
Expand All @@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]):
)

def get_flags(self, *flags) -> Mapping[str, str]:
return {flag: self._jbackend.getFlag(flag) for flag in flags}
return {flag: self._jbackend.pyGetFlag(flag) for flag in flags}

# Hand reference genome `rg` to `self._jbackend`: `rg._config` is serialised
# with `orjson.dumps`, decoded as UTF-8 text, and passed to `pyAddReference`.
def _add_reference_to_scala_backend(self, rg):
self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8'))
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3880,7 +3880,7 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)


class JavaIR(IR):
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,4 +1215,4 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/HailFeatureFlags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class HailFeatureFlags private (
flags.update(flag, value)
}

/** Builds a new `HailFeatureFlags` from the current `flags` plus one extra
  * `(name, value)` entry.
  *
  * @param feature the flag name paired with the value to associate with it
  * @return a freshly constructed `HailFeatureFlags`
  */
def +(feature: (String, String)): HailFeatureFlags = {
  val (name, value) = feature
  new HailFeatureFlags(flags + (name -> value))
}

/** Returns the value stored under `flag` by applying `flags` to it directly.
  *
  * NOTE(review): what happens for an unknown flag depends on the type of `flags`, which is
  * not visible here; `lookup` is declared to return an `Option[String]` instead.
  */
def get(flag: String): String = flags(flag)

def lookup(flag: String): Option[String] =
Expand Down
100 changes: 16 additions & 84 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand All @@ -20,7 +20,6 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

Expand All @@ -29,7 +28,7 @@ import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

object Backend {
Expand All @@ -41,13 +40,6 @@ object Backend {
s"hail_query_$id"
}

// Most recently issued IR id; stays at 0 until `nextIRID` is first called.
private var irID: Int = 0

/** Bumps the counter by one and returns the new value, so the first id handed out is 1. */
def nextIRID(): Int = {
  irID = irID + 1
  irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
Expand All @@ -66,6 +58,9 @@ object Backend {
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

/** Renders a JSON value compactly and encodes the text as UTF-8.
  *
  * @param f by-name JSON value; it is evaluated when passed to `JsonMethods.compact`
  * @return the UTF-8 bytes of the compact rendering
  */
def jsonToBytes(f: => JValue): Array[Byte] = {
  val rendered = JsonMethods.compact(f)
  rendered.getBytes(StandardCharsets.UTF_8)
}
}

/** Abstract holder exposing a single value of type `T` through `value`. */
abstract class BroadcastValue[T] {
  def value: T
}
Expand All @@ -89,14 +84,6 @@ abstract class Backend extends Closeable {

val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

/** Stores `ir` in `persistedIR` under a freshly allocated id.
  *
  * @param ir the IR node to keep
  * @return the id obtained from `Backend.nextIRID()` under which `ir` was stored
  */
protected[this] def addJavaIR(ir: BaseIR): Int = {
  val freshId = Backend.nextIRID()
  persistedIR(freshId) = ir
  freshId
}

/** Drops the entry stored under `id` from `persistedIR`; the result of `remove` is discarded. */
def removeJavaIR(id: Int): Unit = persistedIR.remove(id)

def defaultParallelism: Int

/** Capability flag; this base definition answers `true`.
  * NOTE(review): presumably overridden by backends that cannot run tasks on the driver - confirm.
  */
def canExecuteParallelTasksOnDriver: Boolean = true
Expand Down Expand Up @@ -133,31 +120,6 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

var references: Map[String, ReferenceGenome] = Map.empty

/** Replaces `references` wholesale with the map returned by `ReferenceGenome.builtinReferences()`.
  * Because this is an assignment rather than a merge, whatever `references` held before is
  * no longer reachable through it.
  */
def addDefaultReferences(): Unit =
references = ReferenceGenome.builtinReferences()

/** Registers `rg` under its own name.
  *
  * An unused name is simply added to `references`. Re-adding a reference equal to the one
  * already stored under that name does nothing. A different reference under an existing
  * name is reported through `fatal`, passing along the names currently in `references`.
  *
  * @param rg the reference genome to register
  */
def addReference(rg: ReferenceGenome): Unit =
  references.get(rg.name) match {
    case None =>
      references += (rg.name -> rg)
    case Some(existing) if rg != existing =>
      fatal(
        s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " +
          s"@1",
        references.keys.truncatable("\n "),
      )
    case Some(_) =>
      () // same reference already registered: nothing to do
  }

/** Tells whether a reference is currently registered under `name`.
  *
  * @param name key to look for in `references`
  * @return `true` when `references` contains `name`
  */
// Result type spelled out: public members should not rely on inference.
def hasReference(name: String): Boolean = references.contains(name)

/** Rebinds `references` to a map without the entry for `name`. */
def removeReference(name: String): Unit =
references -= name

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -191,9 +153,6 @@ abstract class Backend extends Closeable {

/** Abstract hook: takes a function over an `ExecuteContext` and yields a `T`.
  * NOTE(review): presumably builds a context, applies `f` to it and returns the result;
  * how the context is created and released is left to implementations - confirm there.
  *
  * @param f computation needing an `ExecuteContext`
  * @param E implicit `sourcecode.Enclosing` captured at the call site
  */
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

/** Compact-renders the by-name JSON value `f` and returns the UTF-8 bytes of that text. */
private[this] def jsonToBytes(f: => JValue): Array[Byte] = {
  val compactText = JsonMethods.compact(f)
  compactText.getBytes(StandardCharsets.UTF_8)
}

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
Expand Down Expand Up @@ -222,15 +181,7 @@ abstract class Backend extends Closeable {
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

implicit val formats: Formats = defaultJSONFormats
Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
}
}
def loadReferencesFromDataset(path: String): Array[Byte]

def fromFASTAFile(
name: String,
Expand All @@ -242,18 +193,20 @@ abstract class Backend extends Closeable {
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
jsonToBytes {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
jsonToBytes {
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
}
}
}

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
Expand All @@ -263,27 +216,6 @@ abstract class Backend extends Closeable {
)
}

/** Registers an IR function described by strings and Java lists.
  *
  * The three `java.util.ArrayList` arguments are converted to Scala arrays, and everything
  * is forwarded to `IRFunctionRegistry.registerIR` inside `withExecuteContext`.
  *
  * @param name          name to register the function under
  * @param typeParamStrs type parameters, as strings
  * @param argNameStrs   argument names
  * @param argTypeStrs   argument types, as strings
  * @param returnType    return type, as a string
  * @param bodyStr       function body, as a string
  */
def pyRegisterIR(
  name: String,
  typeParamStrs: java.util.ArrayList[String],
  argNameStrs: java.util.ArrayList[String],
  argTypeStrs: java.util.ArrayList[String],
  returnType: String,
  bodyStr: String,
): Unit =
  withExecuteContext { ctx =>
    val typeParams = typeParamStrs.asScala.toArray
    val argNames = argNameStrs.asScala.toArray
    val argTypes = argTypeStrs.asScala.toArray
    IRFunctionRegistry.registerIR(ctx, name, typeParams, argNames, argTypes, returnType, bodyStr)
  }

/** Abstract entry point taking an `ExecuteContext` and an `IR`.
  * NOTE(review): the meaning of `Left(())` versus `Right((PTuple, Long))` is not visible
  * here - presumably "no value" versus a result type with an address; confirm in implementations.
  */
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
}
return
}

val response: Array[Byte] = exchange.getRequestURI.getPath match {
case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir)
case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir)
Expand Down
9 changes: 7 additions & 2 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ object ExecuteContext {
tmpdir: String,
localTmpdir: String,
backend: Backend,
references: Map[String, ReferenceGenome],
fs: FS,
timer: ExecutionTimer,
tempFileManager: TempFileManager,
Expand All @@ -79,6 +80,7 @@ object ExecuteContext {
tmpdir,
localTmpdir,
backend,
references,
fs,
region,
timer,
Expand Down Expand Up @@ -107,6 +109,7 @@ class ExecuteContext(
val tmpdir: String,
val localTmpdir: String,
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
val timer: ExecutionTimer,
Expand All @@ -128,7 +131,7 @@ class ExecuteContext(
)
}

val stateManager = HailStateManager(backend.references)
val stateManager = HailStateManager(references)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand All @@ -154,7 +157,7 @@ class ExecuteContext(

/** Returns the value of flag `name` by delegating to `flags.get(name)`.
  * NOTE(review): `shouldWriteIRFiles` compares this result against `null`, so an unset flag
  * presumably comes back as `null` - confirm against `HailFeatureFlags`.
  */
def getFlag(name: String): String = flags.get(name)

def getReference(name: String): ReferenceGenome = backend.references(name)
def getReference(name: String): ReferenceGenome = references(name)

/** Answers `true` exactly when the `write_ir_files` flag reads back as non-null. */
def shouldWriteIRFiles(): Boolean = {
  val setting = getFlag("write_ir_files")
  setting != null
}

Expand All @@ -174,6 +177,7 @@ class ExecuteContext(
tmpdir: String = this.tmpdir,
localTmpdir: String = this.localTmpdir,
backend: Backend = this.backend,
references: Map[String, ReferenceGenome] = this.references,
fs: FS = this.fs,
r: Region = this.r,
timer: ExecutionTimer = this.timer,
Expand All @@ -189,6 +193,7 @@ class ExecuteContext(
tmpdir,
localTmpdir,
backend,
references,
fs,
r,
timer,
Expand Down
Loading