Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle dependent context functions #18443

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
def isStructuralTermSelect(tree: Select) =
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
case defn.PolyFunctionOf(_) =>
case defn.FunctionTypeOfMethod(_) =>
false
case RefinedType(parent, rname, rinfo) =>
rname == tree.name || hasRefinement(parent)
Expand Down
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,24 @@ class Definitions {
// - .linkedClass: the ClassSymbol of the enumeration (class E)
sym.owner.linkedClass.typeRef

object FunctionTypeOfMethod {
/** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
* Extracts the method type type and apply info.
*/
def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
ft match
case RefinedType(parent, nme.apply, mt: MethodOrPoly)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd try to optimize it further by inlining the FunctionOf extractor and specializing it to this use case.
Also, maybe split the RefinedType match into two, one for MethodType refinements and the other for PolyType refinements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe split the RefinedType match into two, one for MethodType refinements and the other for PolyType refinements?

We would end up with these cases

ft match
  case RefinedType(parent, nme.apply, mt: MethodType)
  if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
    Some(mt)
  case RefinedType(parent, nme.apply, pt: PolyType)
  if parent.derivesFrom(defn.PolyFunctionClass) =>
    Some(mt)
  ...

Note that in both cases we can have PolyFunctionClass because we now use PolyFunction for function types with erased parameters.

If it is for performance, we could do

ft match
  case RefinedType(parent, nme.apply, mt: MethodOrPoly)
  if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) =>
    Some(mt)
  ...

I pushed this last option.

Copy link
Contributor Author

@nicolasstucki nicolasstucki Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd try to optimize it further by inlining the FunctionOf extractor and specializing it to this use case.

Inlined, optimized, and aligned with #18418/#18486.

if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) =>
Some(mt)
case AppliedType(parent, targs) if isFunctionNType(ft) =>
val isContextual = ft.typeSymbol.name.isContextFunction
val methodType = if isContextual then ContextualMethodType else MethodType
Some(methodType(targs.init, targs.last))
case _ =>
None
}
}

object FunctionOf {
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
val mt = MethodType.companion(isContextual, false)(args, resultType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ object ContextFunctionResults:
*/
def contextResultsAreErased(sym: Symbol)(using Context): Boolean =
def allErased(tp: Type): Boolean = tp.dealias match
case defn.ContextFunctionType(argTpes, resTpe) =>
argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe)
case ft @ defn.FunctionTypeOfMethod(mt: MethodType) if mt.isContextualMethod =>
mt.nonErasedParamCount == 0 && allErased(mt.resType)
case _ => true
contextResultCount(sym) > 0 && allErased(sym.info.finalResultType)

Expand All @@ -68,13 +68,13 @@ object ContextFunctionResults:
*/
def integrateContextResults(tp: Type, crCount: Int)(using Context): Type =
if crCount == 0 then tp
else tp match
else tp.dealias match
case ExprType(rt) =>
integrateContextResults(rt, crCount)
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount))
case defn.ContextFunctionType(argTypes, resType) =>
MethodType(argTypes, integrateContextResults(resType, crCount - 1))
case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
mt.derivedLambdaType(resType = integrateContextResults(mt.resType, crCount - 1))

/** The total number of parameters of method `sym`, not counting
* erased parameters, but including context result parameters.
Expand All @@ -101,7 +101,7 @@ object ContextFunctionResults:
def recur(tp: Type, n: Int): Type =
if n == 0 then tp
else tp match
case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1)
case defn.FunctionTypeOfMethod(mt) => recur(mt.resType, n - 1)
recur(meth.info.finalResultType, depth)

/** Should selection `tree` be eliminated since it refers to an `apply`
Expand All @@ -115,8 +115,8 @@ object ContextFunctionResults:
else tree match
case Select(qual, name) =>
if name == nme.apply then
qual.tpe match
case defn.ContextFunctionType(_, _) =>
qual.tpe.nn.dealias match
case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
integrateSelect(qual, n + 1)
case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs
integrateSelect(qual, n + 1)
Expand Down