From bbba7b9411e9bf039f8f81b63a2cafad7cb1d8e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C4=99drzej=20Rochala?= <48657087+rochala@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:57:25 +0100 Subject: [PATCH] Use comma counting for all signature help types (#19520) Fixes https://github.com/scalameta/metals/issues/6040 [Cherry-picked 1716bcd9dbefbef88def848c09768a698b6b9ed9] --- .../dotty/tools/dotc/util/Signatures.scala | 51 ++++--- .../signaturehelp/SignatureHelpSuite.scala | 126 ++++++++++++++++++ 2 files changed, 158 insertions(+), 19 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/util/Signatures.scala b/compiler/src/dotty/tools/dotc/util/Signatures.scala index 36d1723b9c06..9131f4f761a2 100644 --- a/compiler/src/dotty/tools/dotc/util/Signatures.scala +++ b/compiler/src/dotty/tools/dotc/util/Signatures.scala @@ -2,6 +2,7 @@ package dotty.tools.dotc package util import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.Positioned import dotty.tools.dotc.ast.untpd import dotty.tools.dotc.core.NameOps.* import dotty.tools.dotc.core.StdNames.nme @@ -99,7 +100,7 @@ object Signatures { findEnclosingApply(path, span) match case Apply(fun, params) => applyCallInfo(span, params, fun, false) case UnApply(fun, _, patterns) => unapplyCallInfo(span, fun, patterns) - case appliedTypeTree @ AppliedTypeTree(_, types) => appliedTypeTreeCallInfo(appliedTypeTree, types) + case appliedTypeTree @ AppliedTypeTree(_, types) => appliedTypeTreeCallInfo(span, appliedTypeTree, types) case tp @ TypeApply(fun, types) => applyCallInfo(span, types, fun, isTypeApply = true) case _ => (0, 0, Nil) @@ -154,13 +155,14 @@ object Signatures { * @param fun Function tree which is being applied */ private def appliedTypeTreeCallInfo( + span: Span, fun: tpd.Tree, types: List[tpd.Tree] )(using Context): (Int, Int, List[Signature]) = val typeName = fun.symbol.name.show val typeParams = fun.symbol.typeRef.typeParams.map(_.paramName.show).map(TypeParam.apply(_)) val denot = fun.denot.asSingleDenotation - val activeParameter = (types.length - 1) max 0 + val activeParameter = findCurrentParamIndex(types, span, typeParams.length - 1) val signature = Signature(typeName, List(typeParams), Some(typeName) , None, Some(denot)) (activeParameter, 0, List(signature)) @@ -237,21 +239,8 @@ object Signatures { case _ :: untpd.TypeApply(_, args) :: _ => args case _ => Nil - val currentParamsIndex = (untpdArgs.indexWhere(_.span.contains(span)) match - case -1 if untpdArgs.isEmpty => 0 - case -1 => - commaIndex(untpdArgs, span) match - // comma is before CURSOR, so we are in parameter b example: test("a", CURSOR) - case Some(index) if index <= span.end => untpdArgs.takeWhile(_.span.end < span.start).length - // comma is after CURSOR, so we are in parameter a example: test("a" CURSOR,) - case Some(index) => untpdArgs.takeWhile(_.span.start < span.end).length - 1 - // we are either in first or last parameter - case None => - if untpdArgs.head.span.start >= span.end then 0 - else untpdArgs.length - 1 max 0 - - case n => n - ) min (alternativeSymbol.paramSymss(safeParamssListIndex).length - 1) + val currentParamsIndex = + findCurrentParamIndex(untpdArgs, span, alternativeSymbol.paramSymss(safeParamssListIndex).length - 1) val pre = treeQualifier(fun) val alternativesWithTypes = alternatives.map(_.asSeenFrom(pre.tpe.widenTermRefExpr)) @@ -263,13 +252,37 @@ object Signatures { else (0, 0, Nil) + /** Finds current parameter index + * @param args List of currently applied arguments + * @param span The position of the cursor + * @param maxIndex The maximum index of the parameter in the current apply list + * + * @return Index of the current parameter + */ + private def findCurrentParamIndex(args: List[Positioned], span: Span, maxIndex: Int)(using Context): Int = + (args.indexWhere(_.span.contains(span)) match + case -1 if args.isEmpty => 0 + case -1 => + commaIndex(args, span) match + // comma is before CURSOR, so we are in parameter b example: test("a", CURSOR) + case Some(index) if index <= span.end => args.takeWhile(_.span.end < span.start).length + // comma is after CURSOR, so we are in parameter a example: test("a" CURSOR,) + case Some(index) => args.takeWhile(_.span.start < span.end).length - 1 + // we are either in first or last parameter + case None => + if args.head.span.start >= span.end then 0 + else args.length - 1 max 0 + + case n => n + ) min maxIndex + /** Parser ignores chars between arguments, we have to manually find the index of comma * @param untpdArgs List of applied untyped arguments * @param span The position of the cursor * * @return None if we are in first or last parameter, comma index otherwise */ - private def commaIndex(untpdArgs: List[untpd.Tree], span: Span)(using Context): Option[Int] = + private def commaIndex(untpdArgs: List[Positioned], span: Span)(using Context): Option[Int] = val previousArgIndex = untpdArgs.lastIndexWhere(_.span.end < span.end) for previousArg <- untpdArgs.lift(previousArgIndex) @@ -301,7 +314,7 @@ object Signatures { val paramTypes = extractParamTypess(resultType, denot, patterns.size).flatten.map(stripAllAnnots) val paramNames = extractParamNamess(resultType, denot).flatten - val activeParameter = patterns.takeWhile(_.span.end < span.start).length min (paramTypes.length - 1) + val activeParameter = findCurrentParamIndex(patterns, span, paramTypes.length - 1) val unapplySignature = toUnapplySignature(denot.asSingleDenotation, paramNames, paramTypes).toList (activeParameter, 0, unapplySignature) diff --git a/presentation-compiler/test/dotty/tools/pc/tests/signaturehelp/SignatureHelpSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/signaturehelp/SignatureHelpSuite.scala index 896cbdb3cad2..01d5e03b6c1e 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/signaturehelp/SignatureHelpSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/signaturehelp/SignatureHelpSuite.scala @@ -1373,3 +1373,129 @@ class SignatureHelpSuite extends BaseSignatureHelpSuite: | test(@@""".stripMargin, "" ) + + @Test def `type-var-position` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[A@@, C]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-1` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[@@A, C]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-2` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[A@@ + | , C]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-3` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[A, C@@]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-4` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[A,@@ C]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-5` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[A, @@C]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-6` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[ + | A@@, + | C + | ]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-7` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[ + | A, + | C@@ + | ]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-8` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[ + | A, + | @@C + | ]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-9` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[ + | A, + | C + | @@]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + ) + + @Test def `type-var-position-10` = + check( + """|trait Test[A, B]: + | def doThing[C](f: B => Test[@@ + | A, + | C + | ]) = ??? + |""".stripMargin, + """|Test[A, B]: Test + | ^ + |""".stripMargin + )