Skip to content

Commit

Permalink
Backport "Add support for xsbti.compile.CompileProgress" to LTS (#20754)
Browse files Browse the repository at this point in the history
Backports #18739 to the LTS branch.

PR submitted by the release tooling.
[skip ci]
  • Loading branch information
WojciechMazur authored Jun 23, 2024
2 parents 121a512 + 0de90f3 commit b1526f2
Show file tree
Hide file tree
Showing 23 changed files with 925 additions and 106 deletions.
224 changes: 211 additions & 13 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import typer.Typer
import typer.ImportInfo.withRootImports
import Decorators._
import io.AbstractFile
import Phases.unfusedPhases
import Phases.{unfusedPhases, Phase}

import sbt.interfaces.ProgressCallback

import util._
import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile}
Expand All @@ -32,6 +34,10 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import scala.io.Codec

import Run.Progress
import scala.compiletime.uninitialized
import dotty.tools.dotc.transform.MegaPhase

/** A compiler run. Exports various methods to compile source files */
class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo {

Expand Down Expand Up @@ -155,14 +161,75 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
}

/** The source files of all late entered symbols, as a set */
private var lateFiles = mutable.Set[AbstractFile]()
private val lateFiles = mutable.Set[AbstractFile]()

/** A cache for static references to packages and classes */
val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024)

/** Actions that need to be performed at the end of the current compilation run */
private var finalizeActions = mutable.ListBuffer[() => Unit]()

private var _progress: Progress | Null = null // Set if progress reporting is enabled

private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit =
foldProgress(())(op)

private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T =
val local = _progress
if local != null then
op(using ctx)(local)
else
default

def didEnterUnit(unit: CompilationUnit)(using Context): Boolean =
foldProgress(true /* should progress by default */)(_.tryEnterUnit(unit))

def canProgress()(using Context): Boolean =
foldProgress(true /* not cancelled by default */)(p => !p.checkCancellation())

def doAdvanceUnit()(using Context): Unit =
trackProgress: progress =>
progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase
progress.refreshProgress()

def doAdvanceLate()(using Context): Unit =
trackProgress: progress =>
progress.currentLateUnitCount += 1 // trace that we completed a late compilation
progress.refreshProgress()

private def doEnterPhase(currentPhase: Phase)(using Context): Unit =
trackProgress: progress =>
progress.enterPhase(currentPhase)

/** interrupt the thread and set cancellation state */
private def cancelInterrupted(): Unit =
try
trackProgress(_.cancel())
finally
Thread.currentThread().nn.interrupt()

private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
trackProgress: progress =>
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
if wasRan then
// add an extra traversal now that we completed a (sub)phase
progress.completedTraversalCount += 1
else
// no subphases were ran, remove traversals from expected total
progress.totalTraversals -= currentPhase.traversals

private def tryAdvanceSubPhase()(using Context): Unit =
trackProgress: progress =>
if progress.canAdvanceSubPhase then
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
if !progress.isCancelled() then
progress.tickSubphase()

/** Will be set to true if any of the compiled compilation units contains
* a pureFunctions language import.
*/
Expand Down Expand Up @@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if ctx.settings.YnoDoubleBindings.value then
ctx.base.checkNoDoubleBindings = true

def runPhases(using Context) = {
def runPhases(allPhases: Array[Phase])(using Context) = {
var lastPrintedTree: PrintedTree = NoPrintedTree
val profiler = ctx.profiler
var phasesWereAdjusted = false

for (phase <- ctx.base.allPhases)
if (phase.isRunnable)
for phase <- allPhases do
doEnterPhase(phase)
val phaseWillRun = phase.isRunnable
if phaseWillRun then
Stats.trackTime(s"phase time ms/$phase") {
val start = System.currentTimeMillis
val profileBefore = profiler.beforePhase(phase)
units = phase.runOn(units)
try units = phase.runOn(units)
catch case _: InterruptedException => cancelInterrupted()
profiler.afterPhase(phase, profileBefore)
if (ctx.settings.Xprint.value.containsPhase(phase))
for (unit <- units)
Expand All @@ -260,18 +330,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if !Feature.ccEnabledSomewhere then
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev)
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase)

end if
end if
end if
doAdvancePhase(phase, wasRan = phaseWillRun)
end for
profiler.finished()
}

val runCtx = ctx.fresh
runCtx.setProfiler(Profiler())
unfusedPhases.foreach(_.initContext(runCtx))
runPhases(using runCtx)
val fusedPhases = runCtx.base.allPhases
runCtx.withProgressCallback: cb =>
_progress = Progress(cb, this, fusedPhases.map(_.traversals).sum)
runPhases(allPhases = fusedPhases)(using runCtx)
if (!ctx.reporter.hasErrors)
Rewrites.writeBack()
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
while (finalizeActions.nonEmpty) {
while (finalizeActions.nonEmpty && canProgress()) {
val action = finalizeActions.remove(0)
action()
}
Expand All @@ -293,10 +370,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
.withRootImports

def process()(using Context) =
ctx.typer.lateEnterUnit(doTypeCheck =>
if typeCheck then
if compiling then finalizeActions += doTypeCheck
else doTypeCheck()
ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
if compiling then finalizeActions += doTypeCheck
else doTypeCheck()
)

process()(using unitCtx)
Expand Down Expand Up @@ -399,7 +475,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
}

object Run {

case class SubPhase(val name: String):
override def toString: String = name

class SubPhases(val phase: Phase):
require(phase.exists)

private def baseName: String = phase match
case phase: MegaPhase => phase.shortPhaseName
case phase => phase.phaseName

val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]"))

def next(using Context): Option[SubPhases] =
val next0 = phase.megaPhase.next.megaPhase
if next0.exists then Some(SubPhases(next0))
else None

def size: Int = all.size

def subPhase(index: Int) =
if index < all.size then all(index)
else baseName


private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
export cb.{cancel, isCancelled}

var totalTraversals: Int = initialTraversals // track how many phases we expect to run
var currentUnitCount: Int = 0 // current unit count in the current (sub)phase
var currentLateUnitCount: Int = 0 // current late unit count
var completedTraversalCount: Int = 0 // completed traversals over all files
var currentCompletedSubtraversalCount: Int = 0 // completed subphases in the current phase
var seenPhaseCount: Int = 0 // how many phases we've seen so far

private var currPhase: Phase = uninitialized // initialized by enterPhase
private var subPhases: SubPhases = uninitialized // initialized by enterPhase
private var currPhaseName: String = uninitialized // initialized by enterPhase
private var nextPhaseName: String = uninitialized // initialized by enterPhase

/** Enter into a new real phase, setting the current and next (sub)phases */
def enterPhase(newPhase: Phase)(using Context): Unit =
if newPhase ne currPhase then
currPhase = newPhase
subPhases = SubPhases(newPhase)
tickSubphase()

def canAdvanceSubPhase: Boolean =
currentCompletedSubtraversalCount + 1 < subPhases.size

/** Compute the current (sub)phase name and next (sub)phase name */
def tickSubphase()(using Context): Unit =
val index = currentCompletedSubtraversalCount
val s = subPhases
currPhaseName = s.subPhase(index)
nextPhaseName =
if index + 1 < s.all.size then s.subPhase(index + 1)
else s.next match
case None => "<end>"
case Some(next0) => next0.subPhase(0)
if seenPhaseCount > 0 then
refreshProgress()


/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
private def currentProgress(): Int =
completedTraversalCount * work() + currentUnitCount + currentLateUnitCount

/**Total progress is computed as the sum of
* - the number of traversals we expect to make over all files
* - the number of late compilations
*/
private def totalProgress(): Int =
totalTraversals * work() + run.lateFiles.size

private def work(): Int = run.files.size

private def requireInitialized(): Unit =
require((currPhase: Phase | Null) != null, "enterPhase was not called")

def checkCancellation(): Boolean =
if Thread.interrupted() then cancel()
isCancelled()

/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
def tryEnterUnit(unit: CompilationUnit): Boolean =
if checkCancellation() then false
else
requireInitialized()
cb.informUnitStarting(currPhaseName, unit)
true

/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
def refreshProgress()(using Context): Unit =
requireInitialized()
val total = totalProgress()
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
cancel()

extension (run: Run | Null)

/** record that the current phase has begun for the compilation unit of the current Context */
def enterUnit(unit: CompilationUnit)(using Context): Boolean =
if run != null then run.didEnterUnit(unit)
else true // don't check cancellation if we're not tracking progress

/** check progress cancellation, true if not cancelled */
def enterRegion()(using Context): Boolean =
if run != null then run.canProgress()
else true // don't check cancellation if we're not tracking progress

/** advance the unit count and record progress in the current phase */
def advanceUnit()(using Context): Unit =
if run != null then run.doAdvanceUnit()

/** if there exists another subphase, switch to it and record progress */
def enterNextSubphase()(using Context): Unit =
if run != null then run.tryAdvanceSubPhase()

/** advance the late count and record progress in the current phase */
def advanceLate()(using Context): Unit =
if run != null then run.doAdvanceLate()

def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage
def enrichErrorMessage(errorMessage: String)(using Context): String =
if run == null then
Expand Down
14 changes: 12 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import scala.annotation.internal.sharable

import DenotTransformers.DenotTransformer
import dotty.tools.dotc.profile.Profiler
import dotty.tools.dotc.sbt.interfaces.IncrementalCallback
import dotty.tools.dotc.sbt.interfaces.{IncrementalCallback, ProgressCallback}
import util.Property.Key
import util.Store
import plugins._
Expand All @@ -53,8 +53,9 @@ object Contexts {
private val (notNullInfosLoc, store8) = store7.newLocation[List[NotNullInfo]]()
private val (importInfoLoc, store9) = store8.newLocation[ImportInfo | Null]()
private val (typeAssignerLoc, store10) = store9.newLocation[TypeAssigner](TypeAssigner)
private val (progressCallbackLoc, store11) = store10.newLocation[ProgressCallback | Null]()

private val initialStore = store10
private val initialStore = store11

/** The current context */
inline def ctx(using ctx: Context): Context = ctx
Expand Down Expand Up @@ -177,6 +178,14 @@ object Contexts {
val local = incCallback
local != null && local.enabled || forceRun

/** The Zinc compile progress callback implementation if we are run from Zinc, null otherwise */
def progressCallback: ProgressCallback | Null = store(progressCallbackLoc)

/** Run `op` if there exists a Zinc progress callback */
inline def withProgressCallback(inline op: ProgressCallback => Unit): Unit =
val local = progressCallback
if local != null then op(local)

/** The current plain printer */
def printerFn: Context => Printer = store(printerFnLoc)

Expand Down Expand Up @@ -675,6 +684,7 @@ object Contexts {

def setCompilerCallback(callback: CompilerCallback): this.type = updateStore(compilerCallbackLoc, callback)
def setIncCallback(callback: IncrementalCallback): this.type = updateStore(incCallbackLoc, callback)
def setProgressCallback(callback: ProgressCallback): this.type = updateStore(progressCallbackLoc, callback)
def setPrinterFn(printer: Context => Printer): this.type = updateStore(printerFnLoc, printer)
def setSettings(settingsState: SettingsState): this.type = updateStore(settingsStateLoc, settingsState)
def setRun(run: Run | Null): this.type = updateStore(runLoc, run)
Expand Down
Loading

0 comments on commit b1526f2

Please sign in to comment.