From 0de90f3efc7697f7513ccd805a060ea101ce3c1a Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 21:19:20 +0200 Subject: [PATCH] simplify subphase traversal [Cherry-picked b51077277085319da721b8aa2aaa0b4fafcbafa5] --- compiler/src/dotty/tools/dotc/Run.scala | 40 ++++++++++------- .../src/dotty/tools/dotc/core/Phases.scala | 9 +++- .../dotty/tools/dotc/typer/TyperPhase.scala | 44 +++++++++---------- .../tools/dotc/sbt/ProgressCallbackTest.scala | 2 +- .../xsbt/CompileProgressSpecification.scala | 2 +- 5 files changed, 56 insertions(+), 41 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 098c8342cfca..40a343fb1267 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -220,14 +220,15 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint // no subphases were ran, remove traversals from expected total progress.totalTraversals -= currentPhase.traversals - private def doAdvanceSubPhase()(using Context): Unit = + private def tryAdvanceSubPhase()(using Context): Unit = trackProgress: progress => - progress.currentUnitCount = 0 // reset unit count in current (sub)phase - progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase - progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase - progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase - if !progress.isCancelled() then - progress.tickSubphase() + if progress.canAdvanceSubPhase then + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase + progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase + progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase + if !progress.isCancelled() then + progress.tickSubphase() /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. @@ -475,6 +476,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint object Run { + case class SubPhase(val name: String): + override def toString: String = name + class SubPhases(val phase: Phase): require(phase.exists) @@ -482,13 +486,15 @@ object Run { case phase: MegaPhase => phase.shortPhaseName case phase => phase.phaseName - val all = IArray.from(phase.subPhases.map(sub => s"$baseName ($sub)")) + val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]")) def next(using Context): Option[SubPhases] = val next0 = phase.megaPhase.next.megaPhase if next0.exists then Some(SubPhases(next0)) else None + def size: Int = all.size + def subPhase(index: Int) = if index < all.size then all(index) else baseName @@ -510,14 +516,17 @@ object Run { private var nextPhaseName: String = uninitialized // initialized by enterPhase /** Enter into a new real phase, setting the current and next (sub)phases */ - private[Run] def enterPhase(newPhase: Phase)(using Context): Unit = + def enterPhase(newPhase: Phase)(using Context): Unit = if newPhase ne currPhase then currPhase = newPhase subPhases = SubPhases(newPhase) tickSubphase() + def canAdvanceSubPhase: Boolean = + currentCompletedSubtraversalCount + 1 < subPhases.size + /** Compute the current (sub)phase name and next (sub)phase name */ - private[Run] def tickSubphase()(using Context): Unit = + def tickSubphase()(using Context): Unit = val index = currentCompletedSubtraversalCount val s = subPhases currPhaseName = s.subPhase(index) @@ -546,12 +555,12 @@ object Run { private def requireInitialized(): Unit = require((currPhase: Phase | Null) != null, "enterPhase was not called") - private[Run] def checkCancellation(): Boolean = + def checkCancellation(): Boolean = if Thread.interrupted() then cancel() isCancelled() /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */ - private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean = + def tryEnterUnit(unit: CompilationUnit): Boolean = if checkCancellation() then false else requireInitialized() @@ -559,7 +568,7 @@ object Run { true /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ - private[Run] def refreshProgress()(using Context): Unit = + def refreshProgress()(using Context): Unit = requireInitialized() val total = totalProgress() if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then @@ -581,8 +590,9 @@ object Run { def advanceUnit()(using Context): Unit = if run != null then run.doAdvanceUnit() - def advanceSubPhase()(using Context): Unit = - if run != null then run.doAdvanceSubPhase() + /** if there exists another subphase, switch to it and record progress */ + def enterNextSubphase()(using Context): Unit = + if run != null then run.tryAdvanceSubPhase() /** advance the late count and record progress in the current phase */ def advanceLate()(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 31e07001a4a2..d6a49186b539 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -318,7 +318,7 @@ object Phases { def runsAfter: Set[String] = Set.empty /** for purposes of progress tracking, overridden in TyperPhase */ - def subPhases: List[String] = Nil + def subPhases: List[Run.SubPhase] = Nil final def traversals: Int = if subPhases.isEmpty then 1 else subPhases.length /** @pre `isRunnable` returns true */ @@ -460,6 +460,13 @@ object Phases { else false + inline def runSubPhase[T](id: Run.SubPhase)(inline body: (Run.SubPhase, Context) ?=> T)(using Context): T = + given Run.SubPhase = id + try + body + finally + ctx.run.enterNextSubphase() + /** Do not run if compile progress has been cancelled */ final def cancellable(body: Context ?=> Unit)(using Context): Boolean = if ctx.run.enterRegion() then diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index 10796dce2e7c..857ed1bad4d9 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -3,6 +3,7 @@ package dotc package typer import core._ +import Run.SubPhase import Phases._ import Contexts._ import Symbols._ @@ -31,13 +32,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { // Run regardless of parsing errors override def isRunnable(implicit ctx: Context): Boolean = true - def enterSyms(using Context): Boolean = monitor("indexing") { + def enterSyms(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit ctx.typer.index(unit.untpdTree) typr.println("entered: " + unit.source) } - def typeCheck(using Context): Boolean = monitor("typechecking") { + def typeCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit try if !unit.suspended then @@ -49,7 +50,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { catch case _: CompilationUnit.SuspendException => () } - def javaCheck(using Context): Boolean = monitor("checking java") { + def javaCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit if unit.isJava then JavaChecks.check(unit.tpdTree) @@ -58,10 +59,11 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { protected def discardAfterTyper(unit: CompilationUnit)(using Context): Boolean = unit.isJava || unit.suspended - /** Keep synchronised with `monitor` subcalls */ - override def subPhases: List[String] = List("indexing", "typechecking", "checking java") + override val subPhases: List[SubPhase] = List( + SubPhase("indexing"), SubPhase("typechecking"), SubPhase("checkingJava")) override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = + val List(Indexing @ _, Typechecking @ _, CheckingJava @ _) = subPhases: @unchecked val unitContexts = for unit <- units yield val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit) @@ -72,14 +74,12 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { else newCtx - val unitContexts0 = - try - for - unitContext <- unitContexts - if enterSyms(using unitContext) - yield unitContext - finally - ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" + val unitContexts0 = runSubPhase(Indexing) { + for + unitContext <- unitContexts + if enterSyms(using unitContext) + yield unitContext + } ctx.base.parserPhase match { case p: ParserPhase => @@ -91,23 +91,21 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - val unitContexts1 = - try - for - unitContext <- unitContexts0 - if typeCheck(using unitContext) - yield unitContext - finally - ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" + val unitContexts1 = runSubPhase(Typechecking) { + for + unitContext <- unitContexts0 + if typeCheck(using unitContext) + yield unitContext + } record("total trees after typer", ast.Trees.ntrees) - val unitContexts2 = + val unitContexts2 = runSubPhase(CheckingJava) { for unitContext <- unitContexts1 if javaCheck(using unitContext) // after typechecking to avoid cycles yield unitContext - + } val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) newUnits diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index e6e67b997aae..489dc0f1759c 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -57,7 +57,7 @@ final class ProgressCallbackTest extends DottyTest: @Test def cancelMidTyper: Unit = - inspectCancellationAtPhase("typer (typechecking)") + inspectCancellationAtPhase("typer[typechecking]") @Test def cancelErasure: Unit = diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala index 45f9daa70e05..bcdac0547e75 100644 --- a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -52,7 +52,7 @@ class CompileProgressSpecification { val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness Set( "parser", - "typer (indexing)", "typer (typechecking)", "typer (checking java)", + "typer[indexing]", "typer[typechecking]", "typer[checkingJava]", "sbt-deps", "posttyper", "sbt-api",