Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential soundness hole when adding references to a mapped capture set #18758

Merged
merged 3 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 5 additions & 44 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ private val Captures: Key[CaptureSet] = Key()

object ccConfig:

/** Switch whether unpickled function types and byname types should be mapped to
* impure types. With the new gradual typing using Fluid capture sets, this should
* be no longer needed. Also, it has bad interactions with pickling tests.
/** If true, allow mappping capture set variables under captureChecking with maps that are neither
* bijective nor idempotent. We currently do now know how to do this correctly in all
* cases, though.
*/
private[cc] val adaptUnpickledFunctionTypes = false
inline val allowUnsoundMaps = false

/** If true, use `sealed` as encapsulation mechanism instead of the
* previous global retriction that `cap` can't be boxed or unboxed.
Expand All @@ -48,7 +48,7 @@ def isCaptureCheckingOrSetup(using Context): Boolean =
*/
def depFun(args: List[Type], resultType: Type, isContextual: Boolean, paramNames: List[TermName] = Nil)(using Context): Type =
val make = MethodType.companion(isContextual = isContextual)
val mt =
val mt =
if paramNames.length == args.length then make(paramNames, args, resultType)
else make(args, resultType)
mt.toFunctionType(alwaysDependent = true)
Expand Down Expand Up @@ -106,22 +106,6 @@ extension (tree: Tree)
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
case _ => Nil

/** Under pureFunctions, add a @retainsByName(*)` annotation to the argument of
* a by name parameter type, turning the latter into an impure by name parameter type.
*/
def adaptByNameArgUnderPureFuns(using Context): Tree =
if ccConfig.adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere then
val rbn = defn.RetainsByNameAnnot
Annotated(tree,
New(rbn.typeRef).select(rbn.primaryConstructor).appliedTo(
Typed(
SeqLiteral(ref(defn.captureRoot) :: Nil, TypeTree(defn.AnyType)),
TypeTree(defn.RepeatedParamType.appliedTo(defn.AnyType))
)
)
)
else tree

extension (tp: Type)

/** @pre `tp` is a CapturingType */
Expand Down Expand Up @@ -199,29 +183,6 @@ extension (tp: Type)
case _ =>
tp

/** Under pureFunctions, map regular function type to impure function type
*/
def adaptFunctionTypeUnderPureFuns(using Context): Type = tp match
case AppliedType(fn, args)
if ccConfig.adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere && defn.isFunctionClass(fn.typeSymbol) =>
val fname = fn.typeSymbol.name
defn.FunctionType(
fname.functionArity,
isContextual = fname.isContextFunction,
isImpure = true).appliedTo(args)
case _ =>
tp

/** Under pureFunctions, add a @retainsByName(*)` annotation to the argument of
* a by name parameter type, turning the latter into an impure by name parameter type.
*/
def adaptByNameArgUnderPureFuns(using Context): Type =
if ccConfig.adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere then
AnnotatedType(tp,
CaptureAnnotation(CaptureSet.universal, boxed = false)(defn.RetainsByNameAnnot))
else
tp

/** Is type known to be always pure by its class structure,
* so that adding a capture set to it would not make sense?
*/
Expand Down
43 changes: 38 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import util.{SimpleIdentitySet, Property}
import typer.ErrorReporting.Addenda
import util.common.alwaysTrue
import scala.collection.mutable
import config.Config.ccAllowUnsoundMaps

/** A class for capture sets. Capture sets can be constants or variables.
* Capture sets support inclusion constraints <:< where <:< is subcapturing.
Expand Down Expand Up @@ -667,7 +666,7 @@ object CaptureSet:

private def mapIsIdempotent = tm.isInstanceOf[IdempotentCaptRefMap]

assert(ccAllowUnsoundMaps || mapIsIdempotent, tm.getClass)
assert(ccConfig.allowUnsoundMaps || mapIsIdempotent, tm.getClass)

private def whereCreated(using Context): String =
if stack == null then ""
Expand All @@ -683,7 +682,9 @@ object CaptureSet:
// `r` is _one_ possible solution in `source` that would make an `r` appear in this set.
// It's not necessarily the only possible solution, so the scheme is incomplete.
source.tryInclude(elem, this)
else if !mapIsIdempotent && variance <= 0 && !origin.isConst && (origin ne initial) && (origin ne source) then
else if ccConfig.allowUnsoundMaps && !mapIsIdempotent
&& variance <= 0 && !origin.isConst && (origin ne initial) && (origin ne source)
then
// The map is neither a BiTypeMap nor an idempotent type map.
// In that case there's no much we can do.
// The scheme then does not propagate added elements back to source and rejects adding
Expand All @@ -697,8 +698,11 @@ object CaptureSet:
def propagateIf(cond: Boolean): CompareResult =
if cond then propagate else CompareResult.OK

if origin eq source then // elements have to be mapped
val mapped = extrapolateCaptureRef(elem, tm, variance)
val mapped = extrapolateCaptureRef(elem, tm, variance)
def isFixpoint =
mapped.isConst && mapped.elems.size == 1 && mapped.elems.contains(elem)

def addMapped =
val added = mapped.elems.filter(!accountsFor(_))
addNewElems(added)
.andAlso:
Expand All @@ -707,6 +711,35 @@ object CaptureSet:
else CompareResult.Fail(this :: Nil)
.andAlso:
propagateIf(!added.isEmpty)

def failNoFixpoint =
val reason =
if variance <= 0 then i"the set's variance is $variance"
else i"$elem gets mapped to $mapped, which is not a supercapture."
report.warning(em"""trying to add $elem from unrecognized source $origin of mapped set $this$whereCreated
|The reference cannot be added since $reason""")
CompareResult.Fail(this :: Nil)

if origin eq source then // elements have to be mapped
addMapped
.andAlso:
if mapped.isConst then CompareResult.OK
else if mapped.asVar.recordDepsState() then { addAsDependentTo(mapped); CompareResult.OK }
else CompareResult.Fail(this :: Nil)
else if !isFixpoint then
// We did not yet observe the !isFixpoint condition in our tests, but it's a
// theoretical possibility. In that case, it would be inconsistent to both
// add `elem` to the set and back-propagate it. But if `{elem} <:< tm(elem)`
// and the variance of the set is positive, we can soundly add `tm(ref)` to
// the set while back-propagating `ref` as before. Otherwise there's nothing
// obvious left to do except fail (which is always sound).
if variance > 0
&& elem.singletonCaptureSet.subCaptures(mapped, frozen = true).isOK then
// widen to fixpoint. mapped is known to be a fixpoint since tm is idempotent.
// The widening is sound, but loses completeness.
addMapped
else
failNoFixpoint
else if accountsFor(elem) then
CompareResult.OK
else
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/dotty/tools/dotc/config/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,4 @@ object Config {
*/
inline val checkLevelsOnConstraints = false
inline val checkLevelsOnInstantiation = true

/** If true, print capturing types in the form `{c} T`.
* If false, print them in the form `T @retains(c)`.
*/
inline val printCaptureSetsAsPrefix = true

/** If true, allow mappping capture set variables under captureChecking with maps that are neither
* bijective nor idempotent. We currently do now know how to do this correctly in all
* cases, though.
*/
inline val ccAllowUnsoundMaps = false
}
17 changes: 4 additions & 13 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import ast.{Trees, tpd, untpd}
import Trees._
import Decorators._
import transform.SymUtils._
import cc.{adaptFunctionTypeUnderPureFuns, adaptByNameArgUnderPureFuns}
import dotty.tools.dotc.quoted.QuotePatterns

import dotty.tools.tasty.{TastyBuffer, TastyReader}
Expand Down Expand Up @@ -383,7 +382,7 @@ class TreeUnpickler(reader: TastyReader,
// Note that the lambda "rt => ..." is not equivalent to a wildcard closure!
// Eta expansion of the latter puts readType() out of the expression.
case APPLIEDtype =>
postProcessFunction(readType().appliedTo(until(end)(readType())))
readType().appliedTo(until(end)(readType()))
case TYPEBOUNDS =>
val lo = readType()
if nothingButMods(end) then
Expand Down Expand Up @@ -460,8 +459,7 @@ class TreeUnpickler(reader: TastyReader,
val ref = readAddr()
typeAtAddr.getOrElseUpdate(ref, forkAt(ref).readType())
case BYNAMEtype =>
val arg = readType()
ExprType(if withPureFuns then arg else arg.adaptByNameArgUnderPureFuns)
ExprType(readType())
case _ =>
ConstantType(readConstant(tag))
}
Expand Down Expand Up @@ -495,12 +493,6 @@ class TreeUnpickler(reader: TastyReader,
def readTreeRef()(using Context): TermRef =
readType().asInstanceOf[TermRef]

/** Under pureFunctions, map all function types to impure function types,
* unless the unpickled class was also compiled with pureFunctions.
*/
private def postProcessFunction(tp: Type)(using Context): Type =
if withPureFuns then tp else tp.adaptFunctionTypeUnderPureFuns

// ------ Reading definitions -----------------------------------------------------

private def nothingButMods(end: Addr): Boolean =
Expand Down Expand Up @@ -1240,8 +1232,7 @@ class TreeUnpickler(reader: TastyReader,
case SINGLETONtpt =>
SingletonTypeTree(readTree())
case BYNAMEtpt =>
val arg = readTpt()
ByNameTypeTree(if withPureFuns then arg else arg.adaptByNameArgUnderPureFuns)
ByNameTypeTree(readTpt())
case NAMEDARG =>
NamedArg(readName(), readTree())
case EXPLICITtpt =>
Expand Down Expand Up @@ -1453,7 +1444,7 @@ class TreeUnpickler(reader: TastyReader,
val args = until(end)(readTpt())
val tree = untpd.AppliedTypeTree(tycon, args)
val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes))
tree.withType(postProcessFunction(ownType))
tree.withType(ownType)
case ANNOTATEDtpt =>
Annotated(readTpt(), readTree())
case LAMBDAtpt =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.annotation.switch
import reporting._
import cc.{adaptFunctionTypeUnderPureFuns, adaptByNameArgUnderPureFuns}

object Scala2Unpickler {

Expand Down Expand Up @@ -824,7 +823,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
}
val tycon = select(pre, sym)
val args = until(end, () => readTypeRef())
if (sym == defn.ByNameParamClass2x) ExprType(args.head.adaptByNameArgUnderPureFuns)
if (sym == defn.ByNameParamClass2x) ExprType(args.head)
else if (ctx.settings.scalajs.value && args.length == 2 &&
sym.owner == JSDefinitions.jsdefn.ScalaJSJSPackageClass && sym == JSDefinitions.jsdefn.PseudoUnionClass) {
// Treat Scala.js pseudo-unions as real unions, this requires a
Expand All @@ -833,7 +832,6 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
}
else if args.nonEmpty then
tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
.adaptFunctionTypeUnderPureFuns
else if (sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams)
else tycon
case TYPEBOUNDStpe =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
try changePrec(GlobalPrec)(toText(arg) ~ "^" ~ toTextCaptureSet(captureSet))
catch case ex: IllegalCaptureRef => toTextAnnot
if annot.symbol.maybeOwner == defn.RetainsAnnot
&& Feature.ccEnabled && Config.printCaptureSetsAsPrefix && !printDebug
&& Feature.ccEnabled && !printDebug
then toTextRetainsAnnot
else toTextAnnot
case EmptyTree =>
Expand Down
Loading