Skip to content

Commit

Permalink
handle overloaded methods
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Oct 11, 2023
1 parent c6366fd commit 553b669
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package scala.meta.internal.metals

import scala.collection.mutable.ListBuffer
import scala.util.Success
import scala.util.Try

Expand Down Expand Up @@ -30,21 +31,60 @@ class ScaladocDefinitionProvider(
buffer <- buffers.get(path)
position <- params.getPosition().toMeta(Input.String(buffer))
symbol <- extractScalaDocLinkAtPos(buffer, position)
scalaMetaSymbols = symbol.toScalaMetaSymbols(
getPackageAndThisSymbols(path, position)
)
contextSymbols = getContext(path, position)
scalaMetaSymbols = symbol.toScalaMetaSymbols(contextSymbols)
_ = scribe.debug(
s"looking for definition for scaladoc symbol: $symbol considering alternatives: ${scalaMetaSymbols
.map(_.showSymbol)
.mkString(", ")}"
)
definitionResult <- scalaMetaSymbols.collectFirst { sym =>
Try(destinationProvider.fromSymbol(sym, Some(path))) match {
case Success(Some(value)) => value
search(sym, path) match {
case Some(value) => value
}
}
} yield definitionResult
}

private def search(symbol: ScalaDocLinkSymbol, path: AbsolutePath) =
symbol match {
case method: MethodSymbol => findAllOverLoadedMethods(method, path)
case StringSymbol(symbol) =>
Try(destinationProvider.fromSymbol(symbol, Some(path))).toOption.flatten
.filter(_.symbol == symbol)
}

private def findAllOverLoadedMethods(
method: MethodSymbol,
path: AbsolutePath,
) = {
var ident: Int = 0
val results: ListBuffer[DefinitionResult] = new ListBuffer
var ok: Boolean = true
while (ok) {
val currentSymbol = method.symbol(ident)
Try(
destinationProvider.fromSymbol(currentSymbol, Some(path))
) match {
case Success(Some(value)) if value.symbol == currentSymbol =>
ident += 1
results.addOne(value)
case _ => ok = false
}
}

if (results.isEmpty) None
else
Some(
new DefinitionResult(
results.toList.flatMap(_.locations.asScala).asJava,
results.head.symbol,
None,
None,
)
)
}

private def extractScalaDocLinkAtPos(
buffer: String,
position: Position,
Expand All @@ -59,7 +99,7 @@ class ScaladocDefinitionProvider(
symbol <- ScalaDocLink.atOffset(comment.text, offset)
} yield symbol

private def getPackageAndThisSymbols(
private def getContext(
path: AbsolutePath,
pos: Position,
): ContextSymbols = {
Expand Down Expand Up @@ -108,7 +148,9 @@ class ScaladocDefinitionProvider(
}

case class ScalaDocLink(value: String) {
def toScalaMetaSymbols(contextSymbols: => ContextSymbols): List[String] =
def toScalaMetaSymbols(
contextSymbols: => ContextSymbols
): List[ScalaDocLinkSymbol] =
if (value.isEmpty()) List.empty
else {
val symbol = symbolWithFixedPackages
Expand All @@ -127,15 +169,33 @@ case class ScalaDocLink(value: String) {
}
}

symbol.last match {
case '#' | '.' | '/' => withPrefixes
case '$' =>
withPrefixes.flatMap(sym =>
List(s"${sym.dropRight(1)}.", s"${sym.dropRight(1)}().")
)
case '!' => withPrefixes.flatMap(sym => List(s"${sym.dropRight(1)}#"))
case _ =>
withPrefixes.flatMap(sym => List(s"$sym#", s"$sym.", s"$sym()."))
List(symbol.indexOf("("), symbol.indexOf("[")).filter(_ >= 0) match {
case Nil =>
symbol.last match {
case '#' | '.' | '/' => withPrefixes.map(StringSymbol(_))
case '$' =>
withPrefixes.flatMap(sym =>
List(
StringSymbol(s"${sym.dropRight(1)}."),
MethodSymbol(s"${sym.dropRight(1)}"),
)
)
case '!' =>
withPrefixes.flatMap(sym =>
List(StringSymbol(s"${sym.dropRight(1)}#"))
)
case _ =>
withPrefixes.flatMap(sym =>
List(
StringSymbol(s"$sym#"),
StringSymbol(s"$sym."),
MethodSymbol(sym),
)
)
}
case list =>
val toDrop = symbol.length - list.min
withPrefixes.flatMap(sym => List(MethodSymbol(sym.dropRight(toDrop))))
}
}

Expand Down Expand Up @@ -185,3 +245,18 @@ object ContextSymbols {
def empty: ContextSymbols = ContextSymbols(None, None)

}

sealed trait ScalaDocLinkSymbol {
def showSymbol: String
}
case class StringSymbol(symbol: String) extends ScalaDocLinkSymbol {
override def showSymbol: String = symbol
}
case class MethodSymbol(prefixSymbol: String) extends ScalaDocLinkSymbol {
def symbol(i: Int): String =
i match {
case 0 => s"$prefixSymbol()."
case _ => s"$prefixSymbol(+$i)."
}
override def showSymbol: String = s"$prefixSymbol(+n)."
}
51 changes: 45 additions & 6 deletions tests/unit/src/test/scala/tests/DefinitionLspSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,13 @@ class DefinitionLspSuite extends BaseLspSuite("definition") {
)
_ = client.messageRequests.clear()
_ <- server.didOpen("a/src/main/scala/a/Main.scala")
definition <- server.definition(
locations <- server.definition(
"a/src/main/scala/a/Main.scala",
testCase,
workspace,
)
_ = assert(definition.nonEmpty)
_ = assert(definition.head.getUri().endsWith("scala/Double.scala"))
_ = assert(locations.nonEmpty)
_ = assert(locations.head.getUri().endsWith("scala/Double.scala"))
} yield ()
}

Expand Down Expand Up @@ -667,13 +667,52 @@ class DefinitionLspSuite extends BaseLspSuite("definition") {
)
_ = client.messageRequests.clear()
_ <- server.didOpen("a/src/main/scala/a/Main.scala")
definition <- server.definition(
locations <- server.definition(
"a/src/main/scala/a/Main.scala",
testCase,
workspace,
)
_ = assert(definition.nonEmpty)
_ = assert(definition.head.getUri().endsWith("a/Main.scala"))
_ = assert(locations.nonEmpty)
_ = assert(locations.head.getUri().endsWith("a/Main.scala"))
} yield ()
}

test("scaladoc-find-all-overridden-methods") {
val testCase =
"""|package a.internal
|
|object O {
| class A {
| /**
| * Calls [[fo@@o]]
| */
| def f: Int = g
| def foo: Int = ???
| def foo(i: Int): Int = ???
| def foo(str: String, i: Int): Int = ???
| }
|}
|""".stripMargin
for {
_ <- initialize(
s"""
|/metals.json
|{
| "a": { }
|}
|/a/src/main/scala/a/Main.scala
|${testCase.replace("@@", "")}
|""".stripMargin
)
_ = client.messageRequests.clear()
_ <- server.didOpen("a/src/main/scala/a/Main.scala")
locations <- server.definition(
"a/src/main/scala/a/Main.scala",
testCase,
workspace,
)
_ = assert(locations.length == 3)
_ = assert(locations.forall(_.getUri().endsWith("a/Main.scala")))
} yield ()
}

Expand Down

0 comments on commit 553b669

Please sign in to comment.