diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index e0b448bef2d..8f04b754f85 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -4,8 +4,7 @@ import is.hail.asm4s._ import is.hail.backend.Backend.jsonToBytes import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ - BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses, - SortField, TableIR, TableReader, + BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader, } import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -107,9 +106,6 @@ abstract class Backend extends Closeable { def shouldCacheQueryInfo: Boolean = true - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -208,23 +204,3 @@ abstract class Backend extends Closeable { def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } - -trait BackendWithCodeCache { - private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50) - - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = { - codeCache.get(k) match { - case Some(v) => v.asInstanceOf[CompiledFunction[T]] - case None => - val compiledFunction = f - codeCache += ((k, compiledFunction)) - compiledFunction - } - } -} - -trait BackendWithNoCodeCache { - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = f -} diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 245ec4d457f..f1f0abe4860 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext +import is.hail.expr.ir.{CodeCacheKey, CompiledFunction} import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS import is.hail.linalg.BlockMatrix @@ -73,6 +74,7 @@ object ExecuteContext { backendContext: BackendContext, irMetadata: IrMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix], + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], )( f: ExecuteContext => T ): T = { @@ -92,6 +94,7 @@ object ExecuteContext { backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f(_)) } } @@ -122,6 +125,7 @@ class ExecuteContext( val backendContext: BackendContext, val irMetadata: IrMetadata, val BlockMatrixCache: mutable.Map[String, BlockMatrix], + val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], ) extends Closeable { val rngNonce: Long = @@ -191,6 +195,7 @@ class ExecuteContext( backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, )( f: ExecuteContext => A ): A = @@ -208,5 +213,6 @@ class ExecuteContext( backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f) } diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 7a34954d243..10498f6debf 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.types._ @@ -70,13 +71,14 @@ object LocalBackend { class LocalBackend( val tmpdir: String, override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { override def backend: Backend = this override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() override def longLifeTempFileManager: TempFileManager = null - private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) + private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -100,6 +102,7 @@ class LocalBackend( }, new IrMetadata(), ImmutableMap.empty, + codeCache, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 62c6099dc45..cb4fd37a901 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -6,10 +6,10 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, - TypeCheck, + IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ @@ -51,7 +51,6 @@ class ServiceBackendContext( ) extends BackendContext with Serializable {} object ServiceBackend { - private val log = Logger.getLogger(getClass.getName()) def apply( jarLocation: String, @@ -132,8 +131,7 @@ class ServiceBackend( val fs: FS, val serviceBackendContext: ServiceBackendContext, val scratchDir: String, -) extends Backend with BackendWithNoCodeCache { - import ServiceBackend.log +) extends Backend with Logging { private[this] var stageCount = 0 private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 @@ -388,6 +386,7 @@ class ServiceBackend( serviceBackendContext, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index f392b809ced..baad79e9b1e 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ @@ -307,7 +308,7 @@ class SparkBackend( override val references: mutable.Map[String, ReferenceGenome], gcsRequesterPaysProject: String, gcsRequesterPaysBuckets: String, -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() @@ -338,8 +339,8 @@ class SparkBackend( override val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) - private[this] val bmCache: BlockMatrixCache = - new BlockMatrixCache() + private[this] val bmCache = new BlockMatrixCache() + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) def createExecuteContextForTests( timer: ExecutionTimer, @@ -363,6 +364,7 @@ class SparkBackend( }, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, ) override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = @@ -383,6 +385,7 @@ class SparkBackend( }, new IrMetadata(), bmCache, + codeCache, )(f) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index d6cff956a38..dbe02dd5e30 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -13,11 +13,12 @@ import is.hail.types.physical.stypes.{ PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType, } import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream} -import is.hail.types.virtual.Type import is.hail.utils._ import java.io.PrintWriter +import sourcecode.Enclosing + case class CodeCacheKey( aggSigs: IndexedSeq[AggStateSig], args: Seq[(Name, EmitParamType)], @@ -32,8 +33,9 @@ case class CompiledFunction[T]( (typ, f) } -object Compile { - def apply[F: TypeInfo]( +object compile { + + def Compile[F: TypeInfo]( ctx: ExecuteContext, params: IndexedSeq[(Name, EmitParamType)], expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], @@ -42,60 +44,18 @@ object Compile { optimize: Boolean = true, print: Option[PrintWriter] = None, ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = - ctx.time { - val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) - val k = - CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck(ctx, ir, BindingEnv.empty) - - val returnParam = CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)) - - val fb = EmitFunctionBuilder[F]( - ctx, - "Compiled", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => - pt - }, - returnParam, - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - assert( - fb.mb.parameterTypeInfo == expectedCodeParamTypes, - s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", - ) - assert( - fb.mb.returnTypeInfo == expectedCodeReturnType, - s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", - ) - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) - CompiledFunction(rt, fb.resultWithIndex(print)) - }).tuple - } -} + Impl[F, AnyVal]( + ctx, + params, + None, + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) -object CompileWithAggregators { - def apply[F: TypeInfo]( + def CompileWithAggregators[F: TypeInfo]( ctx: ExecuteContext, aggSigs: Array[AggStateSig], params: IndexedSeq[(Name, EmitParamType)], @@ -103,60 +63,74 @@ object CompileWithAggregators { expectedCodeReturnType: TypeInfo[_], body: IR, optimize: Boolean = true, + print: Option[PrintWriter] = None, ): ( Option[SingleCodeType], - (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion), + (HailClassLoader, FS, HailTaskContext, Region) => F with FunctionWithAggRegion, ) = + Impl[F, FunctionWithAggRegion]( + ctx, + params, + Some(aggSigs), + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) + + private[this] def Impl[F: TypeInfo, Mixin]( + ctx: ExecuteContext, + params: IndexedSeq[(Name, EmitParamType)], + aggSigs: Option[Array[AggStateSig]], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], + body: IR, + optimize: Boolean, + print: Option[PrintWriter], + )(implicit + E: Enclosing, + N: sourcecode.Name, + ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) = ctx.time { - val normalizedBody = - NormalizeNames(ctx, body, allowFreeVariables = true) - val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck( - ctx, - ir, - BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), - ) - - val fb = EmitFunctionBuilder[F]( - ctx, - "CompiledWithAggs", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, - SingleCodeType.typeInfoFromType(ir.typ), - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) - - val f = fb.resultWithIndex() - CompiledFunction( - rt, - f.asInstanceOf[( - HailClassLoader, - FS, - HailTaskContext, - Region, - ) => (F with FunctionWithAggRegion)], - ) - }).tuple + val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) + ctx.CodeCache.getOrElseUpdate( + CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), { + var ir = Subst( + body, + BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })), + ) + ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx) + TypeCheck(ctx, ir) + + val fb = EmitFunctionBuilder[F]( + ctx, + N.value, + CodeParamType(typeInfo[Region]) +: params.map(_._2), + CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)), + Some("Emit.scala"), + ) + + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) + * } } + * + * visit(ir) } */ + + assert( + fb.mb.parameterTypeInfo == expectedCodeParamTypes, + s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", + ) + assert( + fb.mb.returnTypeInfo == expectedCodeReturnType, + s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", + ) + + val emitContext = EmitContext.analyze(ctx, ir) + val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs) + CompiledFunction(rt, fb.resultWithIndex(print)) + }, + ).asInstanceOf[CompiledFunction[F with Mixin]].tuple } } diff --git a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala index ad51c0ce546..90778e2c96b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.annotations.{Region, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index f52325c6f23..6674d50ab3c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -7,6 +7,7 @@ import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig} import is.hail.expr.ir.analyses.{ ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash, } +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.ndarrays.EmitNDArray import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils} diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index e7628c600d8..1eea7e096e2 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -4,6 +4,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.io.BufferSpec import is.hail.linalg.BlockMatrix diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index cbd2b2bc254..4c528d79a98 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -5,7 +5,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer} import is.hail.backend.spark.{SparkBackend, SparkTaskContext} -import is.hail.expr.ir +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.functions.{ BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction, } @@ -1931,7 +1931,7 @@ case class TableNativeZippedReader( val leftRef = Ref(freshName(), pLeft.virtualType) val rightRef = Ref(freshName(), pRight.virtualType) val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( leftRef.name -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)), @@ -2420,7 +2420,7 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { else if (pred == False()) return TableValueIntermediate(tv.copy(rvd = RVD.empty(ctx, typ.canonicalRVDType))) - val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean]( + val (Some(BooleanSingleCodeType), f) = Compile[AsmFunction3RegionLongLongBoolean]( ctx, FastSeq( ( @@ -3035,7 +3035,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { if (extracted.aggs.isEmpty) { val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3101,7 +3101,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { // 3. load in partition aggregations, comb op as necessary, serialize. // 4. load in partStarts, calculate newRow based on those results. - val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3115,7 +3115,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val serializeF = extracted.serialize(ctx, spec) - val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3138,7 +3138,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec) val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( @@ -3697,7 +3697,7 @@ case class TableKeyByAndAggregate( val localKeyType = keyType val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3723,7 +3723,7 @@ case class TableKeyByAndAggregate( val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3735,7 +3735,7 @@ case class TableKeyByAndAggregate( extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3754,7 +3754,7 @@ case class TableKeyByAndAggregate( ) val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) = - ir.CompileWithAggregators[AsmFunction2RegionLongLong]( + CompileWithAggregators[AsmFunction2RegionLongLong]( ctx, extracted.states, FastSeq(( @@ -3897,7 +3897,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3909,7 +3909,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3933,7 +3933,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val key = Ref(freshName(), keyType.virtualType) val value = Ref(freshName(), valueIR.typ) val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index 4b4d5c3ac4e..84099594230 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -6,6 +6,7 @@ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir import is.hail.expr.ir._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.io.BufferSpec import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical.stypes.EmitType @@ -247,7 +248,7 @@ class Aggs( def deserialize(ctx: ExecuteContext, spec: BufferSpec) : ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -268,7 +269,7 @@ class Aggs( def serialize(ctx: ExecuteContext, spec: BufferSpec) : (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -305,7 +306,7 @@ class Aggs( : (() => (RegionPool, HailClassLoader, HailTaskContext)) => ( (Array[Byte], Array[Byte]) => Array[Byte], ) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states ++ states, FastSeq(), diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 859dfcaa5ba..aab36e0f847 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -4,6 +4,7 @@ import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow} import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo} import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.rvd.RVDPartitioner diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index e3423bf9f75..e0af240efc5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -5,6 +5,7 @@ import is.hail.asm4s._ import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs.FS import is.hail.rvd.{RVD, RVDType} diff --git a/hail/src/main/scala/is/hail/utils/Cache.scala b/hail/src/main/scala/is/hail/utils/Cache.scala index 3aa40a6c547..8e924e7ed02 100644 --- a/hail/src/main/scala/is/hail/utils/Cache.scala +++ b/hail/src/main/scala/is/hail/utils/Cache.scala @@ -2,20 +2,29 @@ package is.hail.utils import is.hail.annotations.{Region, RegionMemory} +import scala.collection.mutable +import scala.jdk.CollectionConverters.asScalaIteratorConverter + import java.io.Closeable import java.util import java.util.Map.Entry -class Cache[K, V](capacity: Int) { +class Cache[K, V](capacity: Int) extends mutable.AbstractMap[K, V] { private[this] val m = new util.LinkedHashMap[K, V](capacity, 0.75f, true) { override def removeEldestEntry(eldest: Entry[K, V]): Boolean = size() > capacity } - def get(k: K): Option[V] = synchronized(Option(m.get(k))) + override def +=(kv: (K, V)): Cache.this.type = + synchronized { m.put(kv._1, kv._2); this } + + override def -=(key: K): Cache.this.type = + synchronized { m.remove(key); this } - def +=(p: (K, V)): Unit = synchronized(m.put(p._1, p._2)) + override def get(key: K): Option[V] = + synchronized(Option(m.get(key))) - def size: Int = synchronized(m.size()) + override def iterator: Iterator[(K, V)] = + for { e <- m.entrySet().iterator().asScala } yield (e.getKey, e.getValue) } class LongToRegionValueCache(capacity: Int) extends Closeable { diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala index 926cf7753d3..ff7a4ed4868 100644 --- a/hail/src/test/scala/is/hail/TestUtils.scala +++ b/hail/src/test/scala/is/hail/TestUtils.scala @@ -4,6 +4,7 @@ import is.hail.annotations.{Region, RegionValueBuilder, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LowererUnsupportedOperation import is.hail.io.vcf.MatrixVCFReader import is.hail.types.physical.{PBaseStruct, PCanonicalArray, PType} diff --git a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala index 9cd09daf7e6..910fef630e1 100644 --- a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala @@ -4,6 +4,7 @@ import is.hail.{ExecStrategy, HailSuite} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.expr.ir.agg._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.io.BufferSpec import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ diff --git a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala index 0c1b0393fe2..4ab89a11177 100644 --- a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala @@ -6,6 +6,7 @@ import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.ir.streams.{EmitStream, StreamUtils} import is.hail.types.VirtualTypeWithReq