From 69fb9a177754d267c66ac37321b4709372e49b7e Mon Sep 17 00:00:00 2001 From: tgodzik Date: Fri, 1 Dec 2023 11:14:21 +0100 Subject: [PATCH] improvement: Support completions for implicit classes Previously, we would only automatically suggest extension methods, not implicit classes. Now, we also properly suggest implicit classes. We could follow up with support for Scala 2 also. --- .../internal/pc/CompilerSearchVisitor.scala | 8 +- .../meta/internal/pc/SemanticdbSymbols.scala | 15 ++ .../pc/completions/CompletionProvider.scala | 5 +- .../pc/completions/CompletionValue.scala | 14 ++ .../internal/pc/completions/Completions.scala | 25 ++- .../internal/mtags/ScalaToplevelMtags.scala | 90 ++++++--- .../pc/CompletionExtensionMethodSuite.scala | 186 ++++++++++++++++++ .../mtest/src/main/scala/tests/PCSuite.scala | 2 +- .../feature/CompletionCrossLspSuite.scala | 48 +++++ 9 files changed, 365 insertions(+), 28 deletions(-) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/CompilerSearchVisitor.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/CompilerSearchVisitor.scala index aef1b57b89e..6cdc980d473 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/CompilerSearchVisitor.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/CompilerSearchVisitor.scala @@ -10,6 +10,7 @@ import scala.meta.internal.metals.ReportContext import scala.meta.pc.* import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.Names.* import dotty.tools.dotc.core.Symbols.* @@ -21,7 +22,12 @@ class CompilerSearchVisitor( val logger: Logger = Logger.getLogger(classOf[CompilerSearchVisitor].getName) private def isAccessible(sym: Symbol): Boolean = try - sym != NoSymbol && sym.isPublic && sym.isStatic + sym != NoSymbol && sym.isPublic && sym.isStatic || { + val owner = sym.maybeOwner + owner != NoSymbol && owner.isClass && + owner.is(Flags.Implicit) && + owner.isStatic && owner.isPublic + } catch case err: AssertionError => logger.log(Level.WARNING, err.getMessage()) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/SemanticdbSymbols.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/SemanticdbSymbols.scala index c8749645d91..92429cbdd73 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/SemanticdbSymbols.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/SemanticdbSymbols.scala @@ -48,7 +48,20 @@ object SemanticdbSymbols: // however in scalac this method is defined only in `module Files` if typeSym.is(JavaDefined) then typeSym :: owner.info.decl(termName(value)).symbol :: Nil + /** + * Looks like decl doesn't work for: + * package a: + * implicit class A (i: Int): + * def inc = i + 1 + */ + else if typeSym == NoSymbol then + val searched = typeName(value) + owner.info.allMembers + .find(_.name == searched) + .map(_.symbol) + .toList else typeSym :: Nil + end if case Descriptor.Term(value) => val outSymbol = owner.info.decl(termName(value)).symbol if outSymbol.exists @@ -91,6 +104,8 @@ object SemanticdbSymbols: .map(_.symbol) .filter(sym => symbolName(sym) == s) .toList + end match + end tryMember parentSymbol.flatMap(tryMember) try diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionProvider.scala index 03318c82e8e..f04f3f01e6a 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionProvider.scala @@ -224,7 +224,7 @@ class CompletionProvider( def mkItemWithImports( v: CompletionValue.Workspace | CompletionValue.Extension | - CompletionValue.Interpolator + CompletionValue.Interpolator | CompletionValue.ImplicitClass ) = val sym = v.symbol path match @@ -273,7 +273,8 @@ class CompletionProvider( end mkItemWithImports completion match - case v: (CompletionValue.Workspace | CompletionValue.Extension) => + case v: (CompletionValue.Workspace | CompletionValue.Extension | + CompletionValue.ImplicitClass) => mkItemWithImports(v) case v: CompletionValue.Interpolator if v.isWorkspace || v.isExtension => mkItemWithImports(v) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionValue.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionValue.scala index 81e7fd15ee0..1ad079405e7 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionValue.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/CompletionValue.scala @@ -128,6 +128,20 @@ object CompletionValue: override def description(printer: MetalsPrinter)(using Context): String = s"${printer.completionSymbol(symbol)} (extension)" + /** + * CompletionValue for old implicit classes methods via SymbolSearch + */ + case class ImplicitClass( + label: String, + symbol: Symbol, + override val snippetSuffix: CompletionSuffix, + override val importSymbol: Symbol, + ) extends Symbolic: + override def completionItemKind(using Context): CompletionItemKind = + CompletionItemKind.Method + override def description(printer: MetalsPrinter)(using Context): String = + s"${printer.completionSymbol(symbol)} (implicit)" + /** * @param shortenedNames shortened type names by `Printer`. This field should be used for autoImports * @param start Starting position of the completion diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/Completions.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/Completions.scala index 7d0b823b5c9..7a3ac0d3491 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/Completions.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/Completions.scala @@ -595,14 +595,35 @@ class Completions( Some(search.search(query, buildTargetIdentifier, visitor)) case CompletionKind.Members => val visitor = new CompilerSearchVisitor(sym => - if sym.is(ExtensionMethod) && + def isExtensionMethod = sym.is(ExtensionMethod) && qualType.widenDealias <:< sym.extensionParam.info.widenDealias - then + + def isImplicitClass(owner: Symbol) = + val constructorParam = + owner.info.allMembers + .find(_.symbol.isAllOf(Flags.PrivateParamAccessor)) + .map(_.info) + owner.isClass && owner.is(Flags.Implicit) && + constructorParam.exists(p => + qualType.widenDealias <:< p.widenDealias + ) + end isImplicitClass + + def isImplicitClassMethod = sym.is(Flags.Method) && + isImplicitClass(sym.maybeOwner) + + if isExtensionMethod then completionsWithSuffix( sym, sym.decodedName, CompletionValue.Extension(_, _, _), ).map(visit).forall(_ == true) + else if isImplicitClassMethod then + completionsWithSuffix( + sym, + sym.decodedName, + CompletionValue.ImplicitClass(_, _, _, sym.maybeOwner), + ).map(visit).forall(_ == true) else false, ) Some(search.searchMethods(query, buildTargetIdentifier, visitor)) diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala index c6cafdbb439..99c7cfae618 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala @@ -92,20 +92,27 @@ class ScalaToplevelMtags( isCaseClassConstructor = true ) ) - def newExpectClassTemplate: Some[ExpectTemplate] = + def newExpectClassTemplate( + isImplicit: Boolean = false + ): Some[ExpectTemplate] = Some( ExpectTemplate( indent, currentOwner, false, false, - isClassConstructor = true + isClassConstructor = true, + isImplicit = isImplicit ) ) def newExpectPkgTemplate: Some[ExpectTemplate] = Some(ExpectTemplate(indent, currentOwner, true, false)) def newExpectExtensionTemplate(owner: String): Some[ExpectTemplate] = Some(ExpectTemplate(indent, owner, false, true)) + def newExpectImplicitTemplate: Some[ExpectTemplate] = + Some( + ExpectTemplate(indent, currentOwner, false, false, isImplicit = true) + ) def newExpectIgnoreBody: Some[ExpectTemplate] = Some( ExpectTemplate( @@ -181,13 +188,16 @@ class ScalaToplevelMtags( val template = expectTemplate match { case Some(expect) if expect.isCaseClassConstructor => newExpectCaseClassTemplate - case _ => newExpectClassTemplate + case Some(expect) => + newExpectClassTemplate(expect.isImplicit) + case _ => + newExpectClassTemplate(isImplicit = false) } loop(indent, isAfterNewline = false, currRegion, template) // also covers extension methods because of `def` inside case DEF // extension group - if (includeMembers && dialect.allowExtensionMethods && currRegion.isExtension) => + if (includeMembers && (dialect.allowExtensionMethods && currRegion.isExtension || currRegion.isImplicit)) => acceptTrivia() newIdentifier.foreach { name => withOwner(currRegion.owner) { @@ -306,7 +316,10 @@ class ScalaToplevelMtags( case COLON if dialect.allowSignificantIndentation => (expectTemplate, nextIsNL()) match { case (Some(expect), true) if needToParseBody(expect) => - val next = expect.startIndentedRegion(currRegion) + val next = expect.startIndentedRegion( + currRegion, + isImplicitClass = expect.isImplicit + ) resetRegion(next) scanner.nextToken() loop(0, isAfterNewline = true, next, None) @@ -340,7 +353,11 @@ class ScalaToplevelMtags( case Some(expect) if needToParseBody(expect) || needToParseExtension(expect) => val next = - expect.startInBraceRegion(currRegion, expect.isExtension) + expect.startInBraceRegion( + currRegion, + expect.isExtension, + expect.isImplicit + ) resetRegion(next) scanner.nextToken() loop(indent, isAfterNewline = false, next, None) @@ -351,7 +368,7 @@ class ScalaToplevelMtags( } case RBRACE => val nextRegion = currRegion match { - case Region.InBrace(_, prev, _, _) => resetRegion(prev) + case Region.InBrace(_, prev, _, _, _) => resetRegion(prev) case r => r } scanner.nextToken() @@ -391,7 +408,7 @@ class ScalaToplevelMtags( indent, isAfterNewline = false, currRegion.prev, - newExpectTemplate + expectTemplate // TODO this might fail tests ) case COMMA => val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody) @@ -422,7 +439,7 @@ class ScalaToplevelMtags( val (shouldCreateClassTemplate, isAfterNewline) = emitEnumCases(region, nextIsNewLine) val nextExpectTemplate = - if (shouldCreateClassTemplate) newExpectClassTemplate + if (shouldCreateClassTemplate) newExpectClassTemplate() else expectTemplate.filter(!_.isPackageBody) loop( indent, @@ -431,6 +448,15 @@ class ScalaToplevelMtags( if (scanner.curr.token == CLASS) newExpectCaseClassTemplate else nextExpectTemplate ) + case IMPLICIT => + scanner.nextToken() + loop( + indent, + isAfterNewline, + currRegion, + newExpectImplicitTemplate, + prevWasDot + ) case t => val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody) scanner.nextToken() @@ -832,7 +858,8 @@ object ScalaToplevelMtags { isExtension: Boolean = false, ignoreBody: Boolean = false, isCaseClassConstructor: Boolean = false, - isClassConstructor: Boolean = false + isClassConstructor: Boolean = false, + isImplicit: Boolean = false ) { /** @@ -845,15 +872,29 @@ object ScalaToplevelMtags { private def adjustRegion(r: Region): Region = if (isPackageBody) r.prev else r - def startInBraceRegion(prev: Region, extension: Boolean = false): Region = - new Region.InBrace(owner, adjustRegion(prev), extension) + def startInBraceRegion( + prev: Region, + extension: Boolean = false, + isImplicitClass: Boolean = false + ): Region = + new Region.InBrace(owner, adjustRegion(prev), extension, isImplicitClass) def startInParenRegion(prev: Region, isCaseClass: Boolean): Region = if (isCaseClass) Region.InParenCaseClass(owner, adjustRegion(prev), true) else Region.InParenClass(owner, adjustRegion(prev)) - def startIndentedRegion(prev: Region, extension: Boolean = false): Region = - new Region.Indented(owner, indent, adjustRegion(prev), extension) + def startIndentedRegion( + prev: Region, + extension: Boolean = false, + isImplicitClass: Boolean = false + ): Region = + new Region.Indented( + owner, + indent, + adjustRegion(prev), + extension, + isImplicitClass: Boolean + ) } @@ -863,6 +904,7 @@ object ScalaToplevelMtags { def acceptMembers: Boolean def produceSourceToplevel: Boolean = termOwner.isPackage def isExtension: Boolean = false + def isImplicit: Boolean = false val overloads: OverloadDisambiguator = new OverloadDisambiguator() def termOwner: String = owner // toplevel terms are wrapped into an artificial Object @@ -898,39 +940,43 @@ object ScalaToplevelMtags { owner: String, prev: Region, extension: Boolean = false, - override val termOwner: String + override val termOwner: String, + override val isImplicit: Boolean ) extends Region { def this( owner: String, prev: Region, - extension: Boolean - ) = this(owner, prev, extension, owner) + extension: Boolean, + isImplicit: Boolean + ) = this(owner, prev, extension, owner, isImplicit) def acceptMembers: Boolean = owner.endsWith("/") override def isExtension = extension override val withTermOwner: String => InBrace = termOwner => - InBrace(owner, prev, extension, termOwner) + InBrace(owner, prev, extension, termOwner, isImplicit) } final case class Indented( owner: String, exitIndent: Int, prev: Region, extension: Boolean = false, - override val termOwner: String + override val termOwner: String, + override val isImplicit: Boolean ) extends Region { def this( owner: String, exitIndent: Int, prev: Region, - extension: Boolean - ) = this(owner, exitIndent, prev, extension, owner) + extension: Boolean, + isImplicit: Boolean + ) = this(owner, exitIndent, prev, extension, owner, isImplicit) def acceptMembers: Boolean = owner.endsWith("/") override def isExtension = extension override val withTermOwner: String => Indented = termOwner => - Indented(owner, exitIndent, prev, extension, termOwner) + Indented(owner, exitIndent, prev, extension, termOwner, isImplicit) } final case class InParenClass( diff --git a/tests/cross/src/test/scala/tests/pc/CompletionExtensionMethodSuite.scala b/tests/cross/src/test/scala/tests/pc/CompletionExtensionMethodSuite.scala index 1fdda5fb12e..acb514fb90d 100644 --- a/tests/cross/src/test/scala/tests/pc/CompletionExtensionMethodSuite.scala +++ b/tests/cross/src/test/scala/tests/pc/CompletionExtensionMethodSuite.scala @@ -21,6 +21,20 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + check( + "simple-old-syntax", + """|package example + | + |object Test: + | implicit class TestOps(a: Int): + | def testOps(b: Int): String = ??? + | + |def main = 100.test@@ + |""".stripMargin, + """|testOps(b: Int): String (implicit) + |""".stripMargin + ) + check( "simple2", """|package example @@ -36,6 +50,21 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { filter = _.contains("(extension)") ) + check( + "simple2-old-syntax", + """|package example + | + |object enrichments: + | implicit class TestOps(a: Int): + | def testOps(b: Int): String = ??? + | + |def main = 100.t@@ + |""".stripMargin, + """|testOps(b: Int): String (implicit) + |""".stripMargin, + filter = _.contains("(implicit)") + ) + check( "simple-empty", """|package example @@ -51,6 +80,21 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { filter = _.contains("(extension)") ) + check( + "simple-empty-old", + """|package example + | + |object enrichments: + | implicit class TestOps(a: Int): + | def testOps(b: Int): String = ??? + | + |def main = 100.@@ + |""".stripMargin, + """|testOps(b: Int): String (implicit) + |""".stripMargin, + filter = _.contains("(implicit)") + ) + check( "filter-by-type", """|package example @@ -68,6 +112,23 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { filter = _.contains("(extension)") ) + check( + "filter-by-type-old", + """|package example + | + |object enrichments: + | implicit class A(num: Int): + | def identity2: Int = num + 1 + | implicit class B(str: String): + | def identity: String = str + | + |def main = "foo".iden@@ + |""".stripMargin, + """|identity: String (implicit) + |""".stripMargin // incr won't be available + + ) + check( "filter-by-type-subtype", """|package example @@ -86,6 +147,24 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { filter = _.contains("(extension)") ) + check( + "filter-by-type-subtype-old", + """|package example + | + |class A + |class B extends A + | + |object enrichments: + | implicit class Test(a: A): + | def doSomething: A = a + | + |def main = (new B).do@@ + |""".stripMargin, + """|doSomething: A (implicit) + |""".stripMargin, + filter = _.contains("(implicit)") + ) + checkEdit( "simple-edit", """|package example @@ -108,6 +187,28 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + checkEdit( + "simple-edit-old", + """|package example + | + |object enrichments: + | implicit class A (num: Int): + | def incr: Int = num + 1 + | + |def main = 100.inc@@ + |""".stripMargin, + """|package example + | + |import example.enrichments.A + | + |object enrichments: + | implicit class A (num: Int): + | def incr: Int = num + 1 + | + |def main = 100.incr + |""".stripMargin + ) + checkEdit( "simple-edit-suffix", """|package example @@ -130,6 +231,28 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + checkEdit( + "simple-edit-suffix-old", + """|package example + | + |object enrichments: + | implicit class A (num: Int): + | def plus(other: Int): Int = num + other + | + |def main = 100.pl@@ + |""".stripMargin, + """|package example + | + |import example.enrichments.A + | + |object enrichments: + | implicit class A (num: Int): + | def plus(other: Int): Int = num + other + | + |def main = 100.plus($0) + |""".stripMargin + ) + // NOTE: In 3.1.3, package object name includes the whole path to file // eg. in 3.2.2 we get `A$package`, but in 3.1.3 `/some/path/to/file/A$package` check( @@ -146,6 +269,20 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + check( + "directly-in-pkg1-old".tag(IgnoreScalaVersion.forLessThan("3.2.2")), + """| + |package examples: + | implicit class A(num: Int): + | def incr: Int = num + 1 + | + |package examples2: + | def main = 100.inc@@ + |""".stripMargin, + """|incr: Int (implicit) + |""".stripMargin + ) + check( "directly-in-pkg2".tag(IgnoreScalaVersion.forLessThan("3.2.2")), """|package example: @@ -160,6 +297,20 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + check( + "directly-in-pkg2-old".tag(IgnoreScalaVersion.forLessThan("3.2.2")), + """|package examples: + | object X: + | def fooBar(num: Int) = num + 1 + | implicit class A (num: Int) { def incr: Int = num + 1 } + | + |package examples2: + | def main = 100.inc@@ + |""".stripMargin, + """|incr: Int (implicit) + |""".stripMargin + ) + checkEdit( "directly-in-pkg3".tag(IgnoreScalaVersion.forLessThan("3.2.2")), """|package example: @@ -177,6 +328,23 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + checkEdit( + "directly-in-pkg3".tag(IgnoreScalaVersion.forLessThan("3.2.2")), + """|package examples: + | implicit class A (num: Int) { def incr: Int = num + 1 } + | + |package examples2: + | def main = 100.inc@@ + |""".stripMargin, + """|import examples.A + |package examples: + | implicit class A (num: Int) { def incr: Int = num + 1 } + | + |package examples2: + | def main = 100.incr + |""".stripMargin + ) + check( "nested-pkg".tag(IgnoreScalaVersion.forLessThan("3.2.2")), """|package a: // some comment @@ -195,4 +363,22 @@ class CompletionExtensionMethodSuite extends BaseCompletionSuite { |""".stripMargin ) + check( + "nested-pkg".tag(IgnoreScalaVersion.forLessThan("3.2.2")), + """|package aa: // some comment + | package cc: + | implicit class A (num: Int): + | def increment2 = num + 2 + | implicit class A (num: Int): + | def increment = num + 1 + | + | + |package bb: + | def main: Unit = 123.incre@@ + |""".stripMargin, + """|increment: Int (implicit) + |increment2: Int (implicit) + |""".stripMargin + ) + } diff --git a/tests/mtest/src/main/scala/tests/PCSuite.scala b/tests/mtest/src/main/scala/tests/PCSuite.scala index 79f40b5b226..d74e0ec1d9e 100644 --- a/tests/mtest/src/main/scala/tests/PCSuite.scala +++ b/tests/mtest/src/main/scala/tests/PCSuite.scala @@ -100,7 +100,7 @@ trait PCSuite { case NonFatal(e) => println(s"warn: ${e.getMessage}") } - workspace.inputs(filename) = (code2, dialect) + workspace.inputs(file.toURI.toString()) = (code2, dialect) } } diff --git a/tests/slow/src/test/scala/tests/feature/CompletionCrossLspSuite.scala b/tests/slow/src/test/scala/tests/feature/CompletionCrossLspSuite.scala index d2a14f6bb1f..991aa42f4ec 100644 --- a/tests/slow/src/test/scala/tests/feature/CompletionCrossLspSuite.scala +++ b/tests/slow/src/test/scala/tests/feature/CompletionCrossLspSuite.scala @@ -106,6 +106,54 @@ class CompletionCrossLspSuite } yield () } + test("implicit-class") { + cleanWorkspace() + for { + _ <- initialize( + s"""/metals.json + |{ + | "a": { "scalaVersion": "${V.scala3}" } + |} + |/a/src/main/scala/a/B.scala + |package b + |implicit class B (num: Int): + | def plus(other: Int) = num + other + |/a/src/main/scala/a/A.scala + |package a + | + |object A { + | // @@ + |} + |""".stripMargin + ) + _ <- server.didOpen("a/src/main/scala/a/B.scala") + _ = assertNoDiagnostics() + _ <- assertCompletionEdit( + "1.p@@", + """|package a + | + |import b.B + | + |object A { + | 1.plus($0) + |} + |""".stripMargin, + filter = _.contains("plus"), + ) + _ <- assertCompletion( + "1.pl@@", + """|plus(other: Int): Int (implicit) + |""".stripMargin, + filter = _.contains("plus"), + ) + _ <- assertCompletion( + "\"plus is not available for string\".plu@@", + "", + filter = _.contains("plus"), + ) + } yield () + } + test("basic-scala3") { cleanWorkspace() for {