diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 23e7d8f8ecf8..830d9ad0a4d4 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -223,6 +223,22 @@ object CheckCaptures: checkNotUniversal.traverse(tpe.widen) end checkNotUniversalInUnboxedResult + trait CheckerAPI: + /** Complete symbol info of a val or a def */ + def completeDef(tree: ValOrDefDef, sym: Symbol)(using Context): Type + + extension [T <: Tree](tree: T) + + /** Set new type of the tree if none was installed yet. */ + def setNuType(tpe: Type): Unit + + /** The new type of the tree, or if none was installed, the original type */ + def nuType(using Context): Type + + /** Was a new type installed for this tree? */ + def hasNuType: Boolean + end CheckerAPI + class CheckCaptures extends Recheck, SymTransformer: thisPhase => @@ -243,7 +259,7 @@ class CheckCaptures extends Recheck, SymTransformer: val ccState1 = new CCState // Dotty problem: Rename to ccState ==> Crash in ExplicitOuter - class CaptureChecker(ictx: Context) extends Rechecker(ictx): + class CaptureChecker(ictx: Context) extends Rechecker(ictx), CheckerAPI: /** The current environment */ private val rootEnv: Env = inContext(ictx): @@ -261,10 +277,6 @@ class CheckCaptures extends Recheck, SymTransformer: */ private val todoAtPostCheck = new mutable.ListBuffer[() => Unit] - override def keepType(tree: Tree) = - super.keepType(tree) - || tree.isInstanceOf[Try] // type of `try` needs tp be checked for * escapes - /** Instantiate capture set variables appearing contra-variantly to their * upper approximation. */ @@ -286,8 +298,8 @@ class CheckCaptures extends Recheck, SymTransformer: */ private def interpolateVarsIn(tpt: Tree)(using Context): Unit = if tpt.isInstanceOf[InferredTypeTree] then - interpolator().traverse(tpt.knownType) - .showing(i"solved vars in ${tpt.knownType}", capt) + interpolator().traverse(tpt.nuType) + .showing(i"solved vars in ${tpt.nuType}", capt) for msg <- ccState.approxWarnings do report.warning(msg, tpt.srcPos) ccState.approxWarnings.clear() @@ -501,11 +513,11 @@ class CheckCaptures extends Recheck, SymTransformer: then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos) else if arg.span.exists then ("", arg.srcPos) else ("", fn.srcPos) - disallowRootCapabilitiesIn(arg.knownType, NoSymbol, + disallowRootCapabilitiesIn(arg.nuType, NoSymbol, i"Type variable $pname of $sym", "be instantiated to", addendum, pos) val param = fn.symbol.paramNamed(pname) - if param.isUseParam then markFree(arg.knownType.deepCaptureSet, pos) + if param.isUseParam then markFree(arg.nuType.deepCaptureSet, pos) end disallowCapInTypeArgs override def recheckIdent(tree: Ident, pt: Type)(using Context): Type = @@ -769,8 +781,8 @@ class CheckCaptures extends Recheck, SymTransformer: */ def checkContains(tree: TypeApply)(using Context): Unit = tree match case ContainsImpl(csArg, refArg) => - val cs = csArg.knownType.captureSet - val ref = refArg.knownType + val cs = csArg.nuType.captureSet + val ref = refArg.nuType capt.println(i"check contains $cs , $ref") ref match case ref: CaptureRef if ref.isTracked => @@ -852,7 +864,7 @@ class CheckCaptures extends Recheck, SymTransformer: case _ => (sym, "") disallowRootCapabilitiesIn( - tree.tpt.knownType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos) + tree.tpt.nuType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos) checkInferredResult(super.recheckValDef(tree, sym), tree) finally if !sym.is(Param) then @@ -1533,7 +1545,7 @@ class CheckCaptures extends Recheck, SymTransformer: private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup] override def checkUnit(unit: CompilationUnit)(using Context): Unit = - setup.setupUnit(unit.tpdTree, completeDef) + setup.setupUnit(unit.tpdTree, this) collectCapturedMutVars.traverse(unit.tpdTree) if ctx.settings.YccPrintSetup.value then @@ -1676,7 +1688,7 @@ class CheckCaptures extends Recheck, SymTransformer: traverseChildren(tp) if tree.isInstanceOf[InferredTypeTree] then - checker.traverse(tree.knownType) + checker.traverse(tree.nuType) end healTypeParam /** Under the unsealed policy: Arrays are like vars, check that their element types @@ -1716,10 +1728,10 @@ class CheckCaptures extends Recheck, SymTransformer: check(tree) def check(tree: Tree)(using Context) = tree match case TypeApply(fun, args) => - fun.knownType.widen match + fun.nuType.widen match case tl: PolyType => val normArgs = args.lazyZip(tl.paramInfos).map: (arg, bounds) => - arg.withType(arg.knownType.forceBoxStatus( + arg.withType(arg.nuType.forceBoxStatus( bounds.hi.isBoxedCapturing | bounds.lo.isBoxedCapturing)) checkBounds(normArgs, tl) args.lazyZip(tl.paramNames).foreach(healTypeParam(_, _, fun.symbol)) @@ -1739,7 +1751,7 @@ class CheckCaptures extends Recheck, SymTransformer: def traverse(t: Tree)(using Context) = t match case tree: InferredTypeTree => case tree: New => - case tree: TypeTree => checkAppliedTypesIn(tree.withKnownType) + case tree: TypeTree => checkAppliedTypesIn(tree.withType(tree.nuType)) case _ => traverseChildren(t) checkApplied.traverse(unit) end postCheck diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index c5c362dbe8dc..ebe128d7776c 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -19,19 +19,16 @@ import printing.{Printer, Texts}, Texts.{Text, Str} import collection.mutable import CCState.* import dotty.tools.dotc.util.NoSourcePosition +import CheckCaptures.CheckerAPI /** Operations accessed from CheckCaptures */ trait SetupAPI: - /** The operation to recheck a ValDef or DefDef */ - type DefRecheck = (tpd.ValOrDefDef, Symbol) => Context ?=> Type - /** Setup procedure to run for each compilation unit * @param tree the typed tree of the unit to check - * @param recheckDef the recheck method to run on completion of symbols with - * inferred (result-) types + * @param checker the capture checker which will run subsequently. */ - def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit + def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit /** Symbol is a term member of a class that was not capture checked * The info of these symbols is made fluid. @@ -378,15 +375,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: tp2 end transformExplicitType - /** Transform type of tree, and remember the transformed type as the type the tree */ - private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit = - if !tree.hasRememberedType then - val transformed = - if tree.isInferred - then transformInferredType(tree.tpe) - else transformExplicitType(tree.tpe, tptToCheck = tree) - tree.rememberType(if boxed then box(transformed) else transformed) - /** Substitute parameter symbols in `from` to paramRefs in corresponding * method or poly types `to`. We use a single BiTypeMap to do everything. * @param from a list of lists of type or term parameter symbols of a curried method @@ -436,7 +424,17 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: atPhase(thisPhase.next)(sym.info) /** A traverser that adds knownTypes and updates symbol infos */ - def setupTraverser(recheckDef: DefRecheck) = new TreeTraverserWithPreciseImportContexts: + def setupTraverser(checker: CheckerAPI) = new TreeTraverserWithPreciseImportContexts: + import checker.* + + /** Transform type of tree, and remember the transformed type as the type the tree */ + private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit = + if !tree.hasNuType then + val transformed = + if tree.isInferred + then transformInferredType(tree.tpe) + else transformExplicitType(tree.tpe, tptToCheck = tree) + tree.setNuType(if boxed then box(transformed) else transformed) /** Transform the type of a val or var or the result type of a def */ def transformResultType(tpt: TypeTree, sym: Symbol)(using Context): Unit = @@ -464,7 +462,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: traverse(parent) case _ => traverseChildren(tp) - addDescription.traverse(tpt.knownType) + addDescription.traverse(tpt.nuType) end transformResultType def traverse(tree: Tree)(using Context): Unit = @@ -504,7 +502,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: case tree @ SeqLiteral(elems, tpt: TypeTree) => traverse(elems) - tpt.rememberType(box(transformInferredType(tpt.tpe))) + tpt.setNuType(box(transformInferredType(tpt.tpe))) case tree: Block => inNestedLevel(traverseChildren(tree)) @@ -537,22 +535,22 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: // with special treatment for constructors. def localReturnType = if sym.isConstructor then constrReturnType(sym.info, sym.paramSymss) - else tree.tpt.knownType + else tree.tpt.nuType // A test whether parameter signature might change. This returns true if one of - // the parameters has a remembered type. The idea here is that we store a remembered + // the parameters has a new type installee. The idea here is that we store a new // type only if the transformed type is different from the original. def paramSignatureChanges = tree.match case tree: DefDef => tree.paramss.nestedExists: - case param: ValDef => param.tpt.hasRememberedType - case param: TypeDef => param.rhs.hasRememberedType + case param: ValDef => param.tpt.hasNuType + case param: TypeDef => param.rhs.hasNuType case _ => false // A symbol's signature changes if some of its parameter types or its result type // have a new type installed here (meaning hasRememberedType is true) def signatureChanges = - tree.tpt.hasRememberedType && !sym.isConstructor || paramSignatureChanges + tree.tpt.hasNuType && !sym.isConstructor || paramSignatureChanges // Replace an existing symbol info with inferred types where capture sets of // TypeParamRefs and TermParamRefs are put in correspondence by BiTypeMaps with the @@ -616,7 +614,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: capt.println(i"forcing $sym, printing = ${ctx.mode.is(Mode.Printing)}") //if ctx.mode.is(Mode.Printing) then new Error().printStackTrace() denot.info = newInfo - recheckDef(tree, sym) + completeDef(tree, sym) updateInfo(sym, updatedInfo) case tree: Bind => @@ -833,8 +831,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: /** Run setup on a compilation unit with given `tree`. * @param recheckDef the function to run for completing a val or def */ - def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit = - setupTraverser(recheckDef).traverse(tree)(using ctx.withPhase(thisPhase)) + def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit = + setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase)) // ------ Checks to run after main capture checking -------------------------- diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index d3173cef252d..172ae337d6e6 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -28,18 +28,29 @@ import dotty.tools.dotc.cc.boxed object Recheck: import tpd.* - /** Attachment key for rechecked types of TypeTrees */ - val RecheckedType = Property.Key[Type] - - val addRecheckedTypes = new TreeMap: - override def transform(tree: Tree)(using Context): Tree = - try - val tree1 = super.transform(tree) - tree.getAttachment(RecheckedType) match - case Some(tpe) => tree1.withType(tpe) - case None => tree1 - catch - case _:TypeError => tree + /** Attachment key for a toplevel tree of a unit that contains a map + * from nodes in that tree to their rechecked "new" types + */ + val RecheckedTypes = Property.Key[util.EqHashMap[Tree, Type]] + + /** If tree carries a RecheckedTypes attachment, use the associated `nuTypes` + * map to produce a new tree that contains at each node the type in the + * map as the node's .tpe field + */ + def addRecheckedTypes(tree: Tree)(using Context): Tree = + tree.getAttachment(RecheckedTypes) match + case Some(nuTypes) => + val withNuTypes = new TreeMap: + override def transform(tree: Tree)(using Context): Tree = + try + val tree1 = super.transform(tree) + val tpe = nuTypes.lookup(tree) + if tpe != null then tree1.withType(tpe) else tree1 + catch + case _: TypeError => tree + withNuTypes.transform(tree) + case None => + tree extension (sym: Symbol)(using Context) @@ -61,30 +72,6 @@ object Recheck: val symd = sym.denot symd.validFor.firstPhaseId == phase.id + 1 && (sym.originDenotation ne symd) - extension [T <: Tree](tree: T) - - /** Remember `tpe` as the type of `tree`, which might be different from the - * type stored in the tree itself, unless a type was already remembered for `tree`. - */ - def rememberType(tpe: Type)(using Context): Unit = - if !tree.hasAttachment(RecheckedType) then rememberTypeAlways(tpe) - - /** Remember `tpe` as the type of `tree`, which might be different from the - * type stored in the tree itself - */ - def rememberTypeAlways(tpe: Type)(using Context): Unit = - if tpe ne tree.knownType then tree.putAttachment(RecheckedType, tpe) - - /** The remembered type of the tree, or if none was installed, the original type */ - def knownType: Type = - tree.attachmentOrElse(RecheckedType, tree.tpe) - - def hasRememberedType: Boolean = tree.hasAttachment(RecheckedType) - - def withKnownType(using Context): T = tree.getAttachment(RecheckedType) match - case Some(tpe) => tree.withType(tpe).asInstanceOf[T] - case None => tree - /** Map ExprType => T to () ?=> T (and analogously for pure versions). * Even though this phase runs after ElimByName, ExprTypes can still occur * as by-name arguments of applied types. See note in doc comment for @@ -172,17 +159,32 @@ abstract class Recheck extends Phase, SymTransformer: class Rechecker(@constructorOnly ictx: Context): private val ta = ictx.typeAssigner - /** If true, remember types of all tree nodes in attachments so that they - * can be retrieved with `knownType` - */ - private val keepAllTypes = inContext(ictx) { - ictx.settings.Xprint.value.containsPhase(thisPhase) - } + private val nuTypes = util.EqHashMap[Tree, Type]() + + extension [T <: Tree](tree: T) + + /** Set new type of the tree if none was installed yet and the new type is different + * from the current type. + */ + def setNuType(tpe: Type): Unit = + if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then nuTypes(tree) = tpe + + /** The new type of the tree, or if none was installed, the original type */ + def nuType(using Context): Type = + val ntpe = nuTypes.lookup(tree) + if ntpe != null then ntpe else tree.tpe + + /** Was a new type installed for this tree? */ + def hasNuType: Boolean = + nuTypes.lookup(tree) != null + end extension - /** Should type of `tree` be kept in an attachment so that it can be retrieved with - * `knownType`? By default true only is `keepAllTypes` hold, but can be overridden. + /** If true, remember the new types of nodes in this compilation unit + * as an attachment in the unit's tpdTree node. By default, this is + * enabled when -Xprint:cc is set. Can be overridden. */ - def keepType(tree: Tree): Boolean = keepAllTypes + def keepNuTypes(using Context): Boolean = + ctx.settings.Xprint.value.containsPhase(thisPhase) /** A map from NamedTypes to the denotations they had before this phase. * Needed so that we can `reset` them after this phase. @@ -343,7 +345,6 @@ abstract class Recheck extends Phase, SymTransformer: def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type = val funtpe = recheck(tree.fun) - tree.fun.rememberType(funtpe) // remember type to support later bounds checks funtpe.widen match case fntpe: PolyType => assert(fntpe.paramInfos.hasSameLengthAs(tree.args)) @@ -459,7 +460,7 @@ abstract class Recheck extends Phase, SymTransformer: seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes)) def recheckTypeTree(tree: TypeTree)(using Context): Type = - tree.knownType // allows to install new types at Setup + tree.nuType // allows to install new types at Setup def recheckAnnotated(tree: Annotated)(using Context): Type = tree.tpe match @@ -558,7 +559,7 @@ abstract class Recheck extends Phase, SymTransformer: */ def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = val tpe1 = checkConforms(tpe, pt, tree) - if keepType(tree) then tree.rememberType(tpe1) + tree.setNuType(tpe1) tpe1 def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = @@ -617,6 +618,7 @@ abstract class Recheck extends Phase, SymTransformer: def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) + if keepNuTypes then unit.tpdTree.putAttachment(RecheckedTypes, nuTypes) end Rechecker @@ -624,7 +626,8 @@ abstract class Recheck extends Phase, SymTransformer: override def show(tree: untpd.Tree)(using Context): String = atPhase(thisPhase): withMode(Mode.Printing): - super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree])) + super.show: + addRecheckedTypes(tree.asInstanceOf[tpd.Tree]) end Recheck /** A class that can be used to test basic rechecking without any customaization */