Skip to content

Commit

Permalink
improvement: Support completions for implicit classes
Browse files Browse the repository at this point in the history
Previously, we would only automatically suggest extension methods, not implicit classes.

Now, we also properly suggest implicit classes.

We could follow up with support for Scala 2 also.
  • Loading branch information
tgodzik committed Dec 1, 2023
1 parent fdb0a99 commit 69fb9a1
Show file tree
Hide file tree
Showing 9 changed files with 365 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import scala.meta.internal.metals.ReportContext
import scala.meta.pc.*

import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Flags
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.Symbols.*

Expand All @@ -21,7 +22,12 @@ class CompilerSearchVisitor(
val logger: Logger = Logger.getLogger(classOf[CompilerSearchVisitor].getName)

private def isAccessible(sym: Symbol): Boolean = try
sym != NoSymbol && sym.isPublic && sym.isStatic
sym != NoSymbol && sym.isPublic && sym.isStatic || {
val owner = sym.maybeOwner
owner != NoSymbol && owner.isClass &&
owner.is(Flags.Implicit) &&
owner.isStatic && owner.isPublic
}
catch
case err: AssertionError =>
logger.log(Level.WARNING, err.getMessage())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,20 @@ object SemanticdbSymbols:
// however in scalac this method is defined only in `module Files`
if typeSym.is(JavaDefined) then
typeSym :: owner.info.decl(termName(value)).symbol :: Nil
/**
* Looks like decl doesn't work for:
* package a:
* implicit class A (i: Int):
* def inc = i + 1
*/
else if typeSym == NoSymbol then
val searched = typeName(value)
owner.info.allMembers
.find(_.name == searched)
.map(_.symbol)
.toList
else typeSym :: Nil
end if
case Descriptor.Term(value) =>
val outSymbol = owner.info.decl(termName(value)).symbol
if outSymbol.exists
Expand Down Expand Up @@ -91,6 +104,8 @@ object SemanticdbSymbols:
.map(_.symbol)
.filter(sym => symbolName(sym) == s)
.toList
end match
end tryMember

parentSymbol.flatMap(tryMember)
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class CompletionProvider(

def mkItemWithImports(
v: CompletionValue.Workspace | CompletionValue.Extension |
CompletionValue.Interpolator
CompletionValue.Interpolator | CompletionValue.ImplicitClass
) =
val sym = v.symbol
path match
Expand Down Expand Up @@ -273,7 +273,8 @@ class CompletionProvider(
end mkItemWithImports

completion match
case v: (CompletionValue.Workspace | CompletionValue.Extension) =>
case v: (CompletionValue.Workspace | CompletionValue.Extension |
CompletionValue.ImplicitClass) =>
mkItemWithImports(v)
case v: CompletionValue.Interpolator if v.isWorkspace || v.isExtension =>
mkItemWithImports(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ object CompletionValue:
override def description(printer: MetalsPrinter)(using Context): String =
s"${printer.completionSymbol(symbol)} (extension)"

/**
* CompletionValue for old implicit classes methods via SymbolSearch
*/
case class ImplicitClass(
label: String,
symbol: Symbol,
override val snippetSuffix: CompletionSuffix,
override val importSymbol: Symbol,
) extends Symbolic:
override def completionItemKind(using Context): CompletionItemKind =
CompletionItemKind.Method
override def description(printer: MetalsPrinter)(using Context): String =
s"${printer.completionSymbol(symbol)} (implicit)"

/**
* @param shortenedNames shortened type names by `Printer`. This field should be used for autoImports
* @param start Starting position of the completion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,14 +595,35 @@ class Completions(
Some(search.search(query, buildTargetIdentifier, visitor))
case CompletionKind.Members =>
val visitor = new CompilerSearchVisitor(sym =>
if sym.is(ExtensionMethod) &&
def isExtensionMethod = sym.is(ExtensionMethod) &&
qualType.widenDealias <:< sym.extensionParam.info.widenDealias
then

def isImplicitClass(owner: Symbol) =
val constructorParam =
owner.info.allMembers
.find(_.symbol.isAllOf(Flags.PrivateParamAccessor))
.map(_.info)
owner.isClass && owner.is(Flags.Implicit) &&
constructorParam.exists(p =>
qualType.widenDealias <:< p.widenDealias
)
end isImplicitClass

def isImplicitClassMethod = sym.is(Flags.Method) &&
isImplicitClass(sym.maybeOwner)

if isExtensionMethod then
completionsWithSuffix(
sym,
sym.decodedName,
CompletionValue.Extension(_, _, _),
).map(visit).forall(_ == true)
else if isImplicitClassMethod then
completionsWithSuffix(
sym,
sym.decodedName,
CompletionValue.ImplicitClass(_, _, _, sym.maybeOwner),
).map(visit).forall(_ == true)
else false,
)
Some(search.searchMethods(query, buildTargetIdentifier, visitor))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,27 @@ class ScalaToplevelMtags(
isCaseClassConstructor = true
)
)
def newExpectClassTemplate: Some[ExpectTemplate] =
def newExpectClassTemplate(
isImplicit: Boolean = false
): Some[ExpectTemplate] =
Some(
ExpectTemplate(
indent,
currentOwner,
false,
false,
isClassConstructor = true
isClassConstructor = true,
isImplicit = isImplicit
)
)
def newExpectPkgTemplate: Some[ExpectTemplate] =
Some(ExpectTemplate(indent, currentOwner, true, false))
def newExpectExtensionTemplate(owner: String): Some[ExpectTemplate] =
Some(ExpectTemplate(indent, owner, false, true))
def newExpectImplicitTemplate: Some[ExpectTemplate] =
Some(
ExpectTemplate(indent, currentOwner, false, false, isImplicit = true)
)
def newExpectIgnoreBody: Some[ExpectTemplate] =
Some(
ExpectTemplate(
Expand Down Expand Up @@ -181,13 +188,16 @@ class ScalaToplevelMtags(
val template = expectTemplate match {
case Some(expect) if expect.isCaseClassConstructor =>
newExpectCaseClassTemplate
case _ => newExpectClassTemplate
case Some(expect) =>
newExpectClassTemplate(expect.isImplicit)
case _ =>
newExpectClassTemplate(isImplicit = false)
}
loop(indent, isAfterNewline = false, currRegion, template)
// also covers extension methods because of `def` inside
case DEF
// extension group
if (includeMembers && dialect.allowExtensionMethods && currRegion.isExtension) =>
if (includeMembers && (dialect.allowExtensionMethods && currRegion.isExtension || currRegion.isImplicit)) =>
acceptTrivia()
newIdentifier.foreach { name =>
withOwner(currRegion.owner) {
Expand Down Expand Up @@ -306,7 +316,10 @@ class ScalaToplevelMtags(
case COLON if dialect.allowSignificantIndentation =>
(expectTemplate, nextIsNL()) match {
case (Some(expect), true) if needToParseBody(expect) =>
val next = expect.startIndentedRegion(currRegion)
val next = expect.startIndentedRegion(
currRegion,
isImplicitClass = expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(0, isAfterNewline = true, next, None)
Expand Down Expand Up @@ -340,7 +353,11 @@ class ScalaToplevelMtags(
case Some(expect)
if needToParseBody(expect) || needToParseExtension(expect) =>
val next =
expect.startInBraceRegion(currRegion, expect.isExtension)
expect.startInBraceRegion(
currRegion,
expect.isExtension,
expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(indent, isAfterNewline = false, next, None)
Expand All @@ -351,7 +368,7 @@ class ScalaToplevelMtags(
}
case RBRACE =>
val nextRegion = currRegion match {
case Region.InBrace(_, prev, _, _) => resetRegion(prev)
case Region.InBrace(_, prev, _, _, _) => resetRegion(prev)
case r => r
}
scanner.nextToken()
Expand Down Expand Up @@ -391,7 +408,7 @@ class ScalaToplevelMtags(
indent,
isAfterNewline = false,
currRegion.prev,
newExpectTemplate
expectTemplate // TODO this might fail tests
)
case COMMA =>
val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody)
Expand Down Expand Up @@ -422,7 +439,7 @@ class ScalaToplevelMtags(
val (shouldCreateClassTemplate, isAfterNewline) =
emitEnumCases(region, nextIsNewLine)
val nextExpectTemplate =
if (shouldCreateClassTemplate) newExpectClassTemplate
if (shouldCreateClassTemplate) newExpectClassTemplate()
else expectTemplate.filter(!_.isPackageBody)
loop(
indent,
Expand All @@ -431,6 +448,15 @@ class ScalaToplevelMtags(
if (scanner.curr.token == CLASS) newExpectCaseClassTemplate
else nextExpectTemplate
)
case IMPLICIT =>
scanner.nextToken()
loop(
indent,
isAfterNewline,
currRegion,
newExpectImplicitTemplate,
prevWasDot
)
case t =>
val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody)
scanner.nextToken()
Expand Down Expand Up @@ -832,7 +858,8 @@ object ScalaToplevelMtags {
isExtension: Boolean = false,
ignoreBody: Boolean = false,
isCaseClassConstructor: Boolean = false,
isClassConstructor: Boolean = false
isClassConstructor: Boolean = false,
isImplicit: Boolean = false
) {

/**
Expand All @@ -845,15 +872,29 @@ object ScalaToplevelMtags {
private def adjustRegion(r: Region): Region =
if (isPackageBody) r.prev else r

def startInBraceRegion(prev: Region, extension: Boolean = false): Region =
new Region.InBrace(owner, adjustRegion(prev), extension)
def startInBraceRegion(
prev: Region,
extension: Boolean = false,
isImplicitClass: Boolean = false
): Region =
new Region.InBrace(owner, adjustRegion(prev), extension, isImplicitClass)

def startInParenRegion(prev: Region, isCaseClass: Boolean): Region =
if (isCaseClass) Region.InParenCaseClass(owner, adjustRegion(prev), true)
else Region.InParenClass(owner, adjustRegion(prev))

def startIndentedRegion(prev: Region, extension: Boolean = false): Region =
new Region.Indented(owner, indent, adjustRegion(prev), extension)
def startIndentedRegion(
prev: Region,
extension: Boolean = false,
isImplicitClass: Boolean = false
): Region =
new Region.Indented(
owner,
indent,
adjustRegion(prev),
extension,
isImplicitClass: Boolean
)

}

Expand All @@ -863,6 +904,7 @@ object ScalaToplevelMtags {
def acceptMembers: Boolean
def produceSourceToplevel: Boolean = termOwner.isPackage
def isExtension: Boolean = false
def isImplicit: Boolean = false
val overloads: OverloadDisambiguator = new OverloadDisambiguator()
def termOwner: String =
owner // toplevel terms are wrapped into an artificial Object
Expand Down Expand Up @@ -898,39 +940,43 @@ object ScalaToplevelMtags {
owner: String,
prev: Region,
extension: Boolean = false,
override val termOwner: String
override val termOwner: String,
override val isImplicit: Boolean
) extends Region {
def this(
owner: String,
prev: Region,
extension: Boolean
) = this(owner, prev, extension, owner)
extension: Boolean,
isImplicit: Boolean
) = this(owner, prev, extension, owner, isImplicit)
def acceptMembers: Boolean =
owner.endsWith("/")

override def isExtension = extension

override val withTermOwner: String => InBrace = termOwner =>
InBrace(owner, prev, extension, termOwner)
InBrace(owner, prev, extension, termOwner, isImplicit)
}
final case class Indented(
owner: String,
exitIndent: Int,
prev: Region,
extension: Boolean = false,
override val termOwner: String
override val termOwner: String,
override val isImplicit: Boolean
) extends Region {
def this(
owner: String,
exitIndent: Int,
prev: Region,
extension: Boolean
) = this(owner, exitIndent, prev, extension, owner)
extension: Boolean,
isImplicit: Boolean
) = this(owner, exitIndent, prev, extension, owner, isImplicit)
def acceptMembers: Boolean =
owner.endsWith("/")
override def isExtension = extension
override val withTermOwner: String => Indented = termOwner =>
Indented(owner, exitIndent, prev, extension, termOwner)
Indented(owner, exitIndent, prev, extension, termOwner, isImplicit)
}

final case class InParenClass(
Expand Down
Loading

0 comments on commit 69fb9a1

Please sign in to comment.