Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for xsbti.compile.CompileProgress #18739

Merged
merged 13 commits into from
Oct 31, 2023
224 changes: 211 additions & 13 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import typer.Typer
import typer.ImportInfo.withRootImports
import Decorators._
import io.AbstractFile
import Phases.unfusedPhases
import Phases.{unfusedPhases, Phase}

import sbt.interfaces.ProgressCallback

import util._
import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile}
Expand All @@ -32,6 +34,10 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import scala.io.Codec

import Run.Progress
import scala.compiletime.uninitialized
import dotty.tools.dotc.transform.MegaPhase

/** A compiler run. Exports various methods to compile source files */
class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo {

Expand Down Expand Up @@ -155,14 +161,75 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
}

/** The source files of all late entered symbols, as a set */
private var lateFiles = mutable.Set[AbstractFile]()
private val lateFiles = mutable.Set[AbstractFile]()

/** A cache for static references to packages and classes */
val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024)

/** Actions that need to be performed at the end of the current compilation run */
private var finalizeActions = mutable.ListBuffer[() => Unit]()

private var _progress: Progress | Null = null // Set if progress reporting is enabled

private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit =
foldProgress(())(op)

private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T =
val local = _progress
if local != null then
op(using ctx)(local)
bishabosha marked this conversation as resolved.
Show resolved Hide resolved
else
default

def didEnterUnit(unit: CompilationUnit)(using Context): Boolean =
foldProgress(true /* should progress by default */)(_.tryEnterUnit(unit))

def canProgress()(using Context): Boolean =
foldProgress(true /* not cancelled by default */)(p => !p.checkCancellation())

def doAdvanceUnit()(using Context): Unit =
trackProgress: progress =>
progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase
progress.refreshProgress()

def doAdvanceLate()(using Context): Unit =
trackProgress: progress =>
progress.currentLateUnitCount += 1 // trace that we completed a late compilation
progress.refreshProgress()

private def doEnterPhase(currentPhase: Phase)(using Context): Unit =
trackProgress: progress =>
progress.enterPhase(currentPhase)

/** interrupt the thread and set cancellation state */
private def cancelInterrupted(): Unit =
try
trackProgress(_.cancel())
finally
Thread.currentThread().nn.interrupt()

private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
trackProgress: progress =>
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
if wasRan then
// add an extra traversal now that we completed a (sub)phase
progress.completedTraversalCount += 1
else
// no subphases were ran, remove traversals from expected total
progress.totalTraversals -= currentPhase.traversals

private def tryAdvanceSubPhase()(using Context): Unit =
trackProgress: progress =>
if progress.canAdvanceSubPhase then
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
if !progress.isCancelled() then
progress.tickSubphase()

/** Will be set to true if any of the compiled compilation units contains
* a pureFunctions language import.
*/
Expand Down Expand Up @@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if ctx.settings.YnoDoubleBindings.value then
ctx.base.checkNoDoubleBindings = true

def runPhases(using Context) = {
def runPhases(allPhases: Array[Phase])(using Context) = {
var lastPrintedTree: PrintedTree = NoPrintedTree
val profiler = ctx.profiler
var phasesWereAdjusted = false

for (phase <- ctx.base.allPhases)
if (phase.isRunnable)
for phase <- allPhases do
doEnterPhase(phase)
val phaseWillRun = phase.isRunnable
if phaseWillRun then
Stats.trackTime(s"phase time ms/$phase") {
val start = System.currentTimeMillis
val profileBefore = profiler.beforePhase(phase)
units = phase.runOn(units)
try units = phase.runOn(units)
catch case _: InterruptedException => cancelInterrupted()
profiler.afterPhase(phase, profileBefore)
if (ctx.settings.Xprint.value.containsPhase(phase))
for (unit <- units)
Expand All @@ -261,18 +331,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
if !Feature.ccEnabledSomewhere then
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev)
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase)

end if
end if
end if
doAdvancePhase(phase, wasRan = phaseWillRun)
end for
profiler.finished()
}

val runCtx = ctx.fresh
runCtx.setProfiler(Profiler())
unfusedPhases.foreach(_.initContext(runCtx))
runPhases(using runCtx)
val fusedPhases = runCtx.base.allPhases
runCtx.withProgressCallback: cb =>
_progress = Progress(cb, this, fusedPhases.map(_.traversals).sum)
runPhases(allPhases = fusedPhases)(using runCtx)
if (!ctx.reporter.hasErrors)
Rewrites.writeBack()
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
while (finalizeActions.nonEmpty) {
while (finalizeActions.nonEmpty && canProgress()) {
val action = finalizeActions.remove(0)
action()
}
Expand All @@ -294,10 +371,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
.withRootImports

def process()(using Context) =
ctx.typer.lateEnterUnit(doTypeCheck =>
if typeCheck then
if compiling then finalizeActions += doTypeCheck
else doTypeCheck()
ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
if compiling then finalizeActions += doTypeCheck
else doTypeCheck()
)

process()(using unitCtx)
Expand Down Expand Up @@ -400,7 +476,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
}

object Run {

case class SubPhase(val name: String):
override def toString: String = name

class SubPhases(val phase: Phase):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason this exists is because TyperPhase is really like three phases, it might be good to allow cancellation in the middle of the phase

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any reason not to split TyperPhase into three phases?

require(phase.exists)

private def baseName: String = phase match
case phase: MegaPhase => phase.shortPhaseName
case phase => phase.phaseName

val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]"))

def next(using Context): Option[SubPhases] =
val next0 = phase.megaPhase.next.megaPhase
if next0.exists then Some(SubPhases(next0))
else None

def size: Int = all.size

def subPhase(index: Int) =
if index < all.size then all(index)
else baseName


private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
export cb.{cancel, isCancelled}

var totalTraversals: Int = initialTraversals // track how many phases we expect to run
var currentUnitCount: Int = 0 // current unit count in the current (sub)phase
var currentLateUnitCount: Int = 0 // current late unit count
var completedTraversalCount: Int = 0 // completed traversals over all files
var currentCompletedSubtraversalCount: Int = 0 // completed subphases in the current phase
var seenPhaseCount: Int = 0 // how many phases we've seen so far

private var currPhase: Phase = uninitialized // initialized by enterPhase
private var subPhases: SubPhases = uninitialized // initialized by enterPhase
private var currPhaseName: String = uninitialized // initialized by enterPhase
private var nextPhaseName: String = uninitialized // initialized by enterPhase

/** Enter into a new real phase, setting the current and next (sub)phases */
def enterPhase(newPhase: Phase)(using Context): Unit =
if newPhase ne currPhase then
currPhase = newPhase
subPhases = SubPhases(newPhase)
tickSubphase()

def canAdvanceSubPhase: Boolean =
currentCompletedSubtraversalCount + 1 < subPhases.size

/** Compute the current (sub)phase name and next (sub)phase name */
def tickSubphase()(using Context): Unit =
val index = currentCompletedSubtraversalCount
val s = subPhases
currPhaseName = s.subPhase(index)
nextPhaseName =
if index + 1 < s.all.size then s.subPhase(index + 1)
else s.next match
case None => "<end>"
case Some(next0) => next0.subPhase(0)
if seenPhaseCount > 0 then
refreshProgress()


/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
private def currentProgress(): Int =
completedTraversalCount * work() + currentUnitCount + currentLateUnitCount

/**Total progress is computed as the sum of
* - the number of traversals we expect to make over all files
* - the number of late compilations
*/
private def totalProgress(): Int =
totalTraversals * work() + run.lateFiles.size

private def work(): Int = run.files.size

private def requireInitialized(): Unit =
require((currPhase: Phase | Null) != null, "enterPhase was not called")

def checkCancellation(): Boolean =
if Thread.interrupted() then cancel()
isCancelled()

/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
def tryEnterUnit(unit: CompilationUnit): Boolean =
if checkCancellation() then false
else
requireInitialized()
cb.informUnitStarting(currPhaseName, unit)
true

/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
def refreshProgress()(using Context): Unit =
requireInitialized()
val total = totalProgress()
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
cancel()

extension (run: Run | Null)

/** record that the current phase has begun for the compilation unit of the current Context */
def enterUnit(unit: CompilationUnit)(using Context): Boolean =
if run != null then run.didEnterUnit(unit)
else true // don't check cancellation if we're not tracking progress

/** check progress cancellation, true if not cancelled */
def enterRegion()(using Context): Boolean =
if run != null then run.canProgress()
else true // don't check cancellation if we're not tracking progress

/** advance the unit count and record progress in the current phase */
def advanceUnit()(using Context): Unit =
if run != null then run.doAdvanceUnit()

/** if there exists another subphase, switch to it and record progress */
def enterNextSubphase()(using Context): Unit =
if run != null then run.tryAdvanceSubPhase()

/** advance the late count and record progress in the current phase */
def advanceLate()(using Context): Unit =
if run != null then run.doAdvanceLate()

def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage
def enrichErrorMessage(errorMessage: String)(using Context): String =
if run == null then
Expand Down
14 changes: 12 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import scala.annotation.internal.sharable

import DenotTransformers.DenotTransformer
import dotty.tools.dotc.profile.Profiler
import dotty.tools.dotc.sbt.interfaces.IncrementalCallback
import dotty.tools.dotc.sbt.interfaces.{IncrementalCallback, ProgressCallback}
import util.Property.Key
import util.Store
import plugins._
Expand All @@ -53,8 +53,9 @@ object Contexts {
private val (notNullInfosLoc, store8) = store7.newLocation[List[NotNullInfo]]()
private val (importInfoLoc, store9) = store8.newLocation[ImportInfo | Null]()
private val (typeAssignerLoc, store10) = store9.newLocation[TypeAssigner](TypeAssigner)
private val (progressCallbackLoc, store11) = store10.newLocation[ProgressCallback | Null]()

private val initialStore = store10
private val initialStore = store11

/** The current context */
inline def ctx(using ctx: Context): Context = ctx
Expand Down Expand Up @@ -177,6 +178,14 @@ object Contexts {
val local = incCallback
local != null && local.enabled || forceRun

/** The Zinc compile progress callback implementation if we are run from Zinc, null otherwise */
def progressCallback: ProgressCallback | Null = store(progressCallbackLoc)

/** Run `op` if there exists a Zinc progress callback */
inline def withProgressCallback(inline op: ProgressCallback => Unit): Unit =
val local = progressCallback
if local != null then op(local)

/** The current plain printer */
def printerFn: Context => Printer = store(printerFnLoc)

Expand Down Expand Up @@ -675,6 +684,7 @@ object Contexts {

def setCompilerCallback(callback: CompilerCallback): this.type = updateStore(compilerCallbackLoc, callback)
def setIncCallback(callback: IncrementalCallback): this.type = updateStore(incCallbackLoc, callback)
def setProgressCallback(callback: ProgressCallback): this.type = updateStore(progressCallbackLoc, callback)
def setPrinterFn(printer: Context => Printer): this.type = updateStore(printerFnLoc, printer)
def setSettings(settingsState: SettingsState): this.type = updateStore(settingsStateLoc, settingsState)
def setRun(run: Run | Null): this.type = updateStore(runLoc, run)
Expand Down
Loading
Loading