diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java index 2dfbb93668a..04c84280e9f 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java @@ -119,6 +119,13 @@ public CompletableFuture> references(References return CompletableFuture.completedFuture(Collections.emptyList()); } + /** + * Returns the inferred expected type. + */ + public CompletableFuture> inferExpectedType(OffsetParams params) { + return CompletableFuture.completedFuture(Optional.empty()); + } + /** * Return decoded and pretty printed TASTy content for .scala or .tasty file. * diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 6433570af84..2aaa91fc46b 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -193,6 +193,15 @@ case class ScalaPresentationCompiler( .asJava } + override def inferExpectedType(params: OffsetParams): CompletableFuture[ju.Optional[String]] = + compilerAccess.withInterruptableCompiler(Some(params))( + Optional.empty(), + params.token, + ) { access => + val driver = access.compiler() + new InferExpectedType(search, driver, params).infer().asJava + } + def shutdown(): Unit = compilerAccess.shutdown() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/InferExpectedType.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/InferExpectedType.scala new file mode 100644 index 00000000000..edd3f40a377 --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/InferExpectedType.scala @@ -0,0 +1,47 @@ +package scala.meta.internal.pc + +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.internal.pc.completions.InterCompletionType +import scala.meta.internal.pc.printer.MetalsPrinter +import scala.meta.pc.OffsetParams +import scala.meta.pc.SymbolSearch + +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver + +class InferExpectedType( + search: SymbolSearch, + driver: InteractiveDriver, + params: OffsetParams +)(implicit rc: ReportContext): + val uri = params.uri + + val sourceFile = CompilerInterfaces.toSource(params.uri, params.text()) + driver.run(uri, sourceFile) + + val ctx = driver.currentCtx + val pos = driver.sourcePosition(params) + + def infer() = + driver.compilationUnits.get(uri) match + case Some(unit) => + val path = + Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx) + val newctx = ctx.fresh.setCompilationUnit(unit) + val tpdPath = + Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using + newctx + ) + val locatedCtx = + Interactive.contextOfPath(tpdPath)(using newctx) + val indexedCtx = IndexedContext(locatedCtx) + val printer = MetalsPrinter.standard( + indexedCtx, + search, + includeDefaultParam = MetalsPrinter.IncludeDefaultParam.ResolveLater, + ) + InterCompletionType.inferType(path)(using newctx).map{ + tpe => printer.tpe(tpe) + } + case None => None diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/SingletonCompletions.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/SingletonCompletions.scala index 1fc66764952..5792a0078cf 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/completions/SingletonCompletions.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/completions/SingletonCompletions.scala @@ -64,10 +64,15 @@ object InterCompletionType: def inferType(path: List[Tree], span: Span)(using Context): Option[Type] = path match + case Typed(expr, tpt) :: _ if expr.span.contains(span) && !tpt.tpe.isErroneous => Some(tpt.tpe) case Block(_, expr) :: rest if expr.span.contains(span) => inferType(rest, span) - case If(cond, _, _) :: rest if !cond.span.contains(span) => - inferType(rest, span) + case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span) + case Alternative(_) :: rest => inferType(rest, span) + case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span) + case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span) + case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span) + case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType) case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span) case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span) diff --git a/tests/cross/src/test/scala/tests/InferExpectedTypeSuite.scala b/tests/cross/src/test/scala/tests/InferExpectedTypeSuite.scala new file mode 100644 index 00000000000..c78101fa751 --- /dev/null +++ b/tests/cross/src/test/scala/tests/InferExpectedTypeSuite.scala @@ -0,0 +1,209 @@ +package tests + +import java.nio.file.Paths + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.internal.metals.EmptyCancelToken +import scala.meta.internal.mtags.MtagsEnrichments._ + +import munit.TestOptions + +class InferExpectedTypeSuite extends BasePCSuite { + override protected def ignoreScalaVersion: Option[IgnoreScalaVersion] = Some( + IgnoreScala2 + ) + + def check( + name: TestOptions, + original: String, + expectedType: String, + fileName: String = "A.scala" + ): Unit = test(name) { + presentationCompiler.restart() + val (code, offset) = params(original.replace("@@", "CURSOR@@"), fileName) + val offsetParams = CompilerOffsetParams( + Paths.get(fileName).toUri(), + code, + offset, + EmptyCancelToken + ) + presentationCompiler.inferExpectedType(offsetParams).get().asScala match { + case Some(value) => assertNoDiff(value, expectedType) + case None => fail("Empty result.") + } + } + + check( + "type-ascription", + """|def doo = (@@ : Double) + |""".stripMargin, + """|Double + |""".stripMargin + ) +// some structures + + check( + "try", + """|val _: Int = + | try { + | @@ + | } catch { + | case _ => + | } + |""".stripMargin, + """|Int + |""".stripMargin + ) + + check( + "try-catch", + """|val _: Int = + | try { + | } catch { + | case _ => @@ + | } + |""".stripMargin, + """|Int + |""".stripMargin + ) + + check( + "if-condition", + """|val _ = if @@ then 1 else 2 + |""".stripMargin, + """|Boolean + |""".stripMargin + ) + + check( + "inline-if", + """|inline def o: Int = inline if ??? then @@ else ??? + |""".stripMargin, + """|Int + |""".stripMargin + ) + +// pattern matching + + check( + "pattern-match", + """|val _ = + | List(1) match + | case @@ + |""".stripMargin, + """|List[Int] + |""".stripMargin + ) + + check( + "bind", + """|val _ = + | List(1) match + | case name @ @@ + |""".stripMargin, + """|List[Int] + |""".stripMargin + ) + + check( + "alternative", + """|val _ = + | List(1) match + | case Nil | @@ + |""".stripMargin, + """|List[Int] + |""".stripMargin + ) + + check( + "unapply".ignore, + """|val _ = + | List(1) match + | case @@ :: _ => + |""".stripMargin, + """|Int + |""".stripMargin + ) + +// generic functions + + check( + "any-generic", + """|val _ : List[Int] = identity(@@) + |""".stripMargin, + """|List[Int] + |""".stripMargin + ) + + check( + "eq-generic", + """|def eq[T](a: T, b: T): Boolean = ??? + |val _ = eq(1, @@) + |""".stripMargin, + """|Int + |""".stripMargin + ) + + check( + "flatmap".ignore, + """|val _ : List[Int] = List().flatMap(_ => @@) + |""".stripMargin, + """|IterableOnce[Int] + |""".stripMargin + ) + + check( + "for-comprehension".ignore, + """|val _ : List[Int] = + | for { + | _ <- List("a", "b") + | } yield @@ + |""".stripMargin, + """|Int + |""".stripMargin + ) + +// bounds + check( + "any".ignore, + """|trait Foo + |def foo[T](a: T): Boolean = ??? + |val _ = foo(@@) + |""".stripMargin, + """|<: Any + |""".stripMargin + ) + + check( + "bounds-1".ignore, + """|trait Foo + |def foo[T <: Foo](a: Foo): Boolean = ??? + |val _ = foo(@@) + |""".stripMargin, + """|<: Foo + |""".stripMargin + ) + + check( + "bounds-2".ignore, + """|trait Foo + |def foo[T :> Foo](a: Foo): Boolean = ??? + |val _ = foo(@@) + |""".stripMargin, + """|:> Foo + |""".stripMargin + ) + + check( + "bounds-3".ignore, + """|trait A + |class B extends A + |class C extends B + |def roo[F >: C <: A](f: F) = ??? + |val kjk = roo(@@) + |""".stripMargin, + """|>: C <: A + |""".stripMargin + ) + +}