diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 4aaef28b9e1e..9751e8272858 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -990,7 +990,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = { def isStructuralTermSelect(tree: Select) = def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match - case defn.PolyFunctionOf(_) => + case defn.FunctionTypeOfMethod(_) => false case RefinedType(parent, rname, rinfo) => rname == tree.name || hasRefinement(parent) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 22a49a760e57..846a1f68cb79 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1108,6 +1108,24 @@ class Definitions { // - .linkedClass: the ClassSymbol of the enumeration (class E) sym.owner.linkedClass.typeRef + object FunctionTypeOfMethod { + /** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(ft: Type)(using Context): Option[MethodOrPoly] = { + ft match + case RefinedType(parent, nme.apply, mt: MethodOrPoly) + if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) => + Some(mt) + case AppliedType(parent, targs) if isFunctionNType(ft) => + val isContextual = ft.typeSymbol.name.isContextFunction + val methodType = if isContextual then ContextualMethodType else MethodType + Some(methodType(targs.init, targs.last)) + case _ => + None + } + } + object FunctionOf { def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type = val mt = MethodType.companion(isContextual, false)(args, resultType) diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index 01a77427698a..1b1d78182f0f 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -58,8 +58,8 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(argTpes, resTpe) => - argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe) + case ft @ defn.FunctionTypeOfMethod(mt: MethodType) if mt.isContextualMethod => + mt.nonErasedParamCount == 0 && allErased(mt.resType) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -68,13 +68,13 @@ object ContextFunctionResults: */ def integrateContextResults(tp: Type, crCount: Int)(using Context): Type = if crCount == 0 then tp - else tp match + else tp.dealias match case ExprType(rt) => integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType) => - MethodType(argTypes, integrateContextResults(resType, crCount - 1)) + case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod => + mt.derivedLambdaType(resType = integrateContextResults(mt.resType, crCount - 1)) /** The total number of parameters of method `sym`, not counting * erased parameters, but including context result parameters. @@ -101,7 +101,7 @@ object ContextFunctionResults: def recur(tp: Type, n: Int): Type = if n == 0 then tp else tp match - case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1) + case defn.FunctionTypeOfMethod(mt) => recur(mt.resType, n - 1) recur(meth.info.finalResultType, depth) /** Should selection `tree` be eliminated since it refers to an `apply` @@ -115,8 +115,8 @@ object ContextFunctionResults: else tree match case Select(qual, name) => if name == nme.apply then - qual.tpe match - case defn.ContextFunctionType(_, _) => + qual.tpe.nn.dealias match + case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod => integrateSelect(qual, n + 1) case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs integrateSelect(qual, n + 1)