Skip to content

Commit

Permalink
try inferring expected type
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Jul 1, 2024
1 parent fe9a85a commit 067d1da
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ public CompletableFuture<java.util.List<ReferencesResult>> references(References
return CompletableFuture.completedFuture(Collections.emptyList());
}

/**
* Returns the inferred expected type.
*/
public CompletableFuture<Optional<String>> inferExpectedType(OffsetParams params) {
return CompletableFuture.completedFuture(Optional.empty());
}

/**
* Return decoded and pretty printed TASTy content for .scala or .tasty file.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ case class ScalaPresentationCompiler(
.asJava
}

override def inferExpectedType(params: OffsetParams): CompletableFuture[ju.Optional[String]] =
compilerAccess.withInterruptableCompiler(Some(params))(
Optional.empty(),
params.token,
) { access =>
val driver = access.compiler()
new InferExpectedType(search, driver, params).infer().asJava
}

def shutdown(): Unit =
compilerAccess.shutdown()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package scala.meta.internal.pc

import scala.meta.internal.metals.ReportContext
import scala.meta.internal.mtags.MtagsEnrichments.*
import scala.meta.internal.pc.completions.InterCompletionType
import scala.meta.internal.pc.printer.MetalsPrinter
import scala.meta.pc.OffsetParams
import scala.meta.pc.SymbolSearch

import dotty.tools.dotc.interactive.Interactive
import dotty.tools.dotc.interactive.InteractiveDriver

class InferExpectedType(
search: SymbolSearch,
driver: InteractiveDriver,
params: OffsetParams
)(implicit rc: ReportContext):
val uri = params.uri

val sourceFile = CompilerInterfaces.toSource(params.uri, params.text())
driver.run(uri, sourceFile)

val ctx = driver.currentCtx
val pos = driver.sourcePosition(params)

def infer() =
driver.compilationUnits.get(uri) match
case Some(unit) =>
val path =
Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx)
val newctx = ctx.fresh.setCompilationUnit(unit)
val tpdPath =
Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using
newctx
)
val locatedCtx =
Interactive.contextOfPath(tpdPath)(using newctx)
val indexedCtx = IndexedContext(locatedCtx)
val printer = MetalsPrinter.standard(
indexedCtx,
search,
includeDefaultParam = MetalsPrinter.IncludeDefaultParam.ResolveLater,
)
InterCompletionType.inferType(path)(using newctx).map{
tpe => printer.tpe(tpe)
}
case None => None
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,15 @@ object InterCompletionType:

def inferType(path: List[Tree], span: Span)(using Context): Option[Type] =
path match
case Typed(expr, tpt) :: _ if expr.span.contains(span) && !tpt.tpe.isErroneous => Some(tpt.tpe)
case Block(_, expr) :: rest if expr.span.contains(span) =>
inferType(rest, span)
case If(cond, _, _) :: rest if !cond.span.contains(span) =>
inferType(rest, span)
case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span)
case Alternative(_) :: rest => inferType(rest, span)
case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span)
case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType)
case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
inferType(rest, span)
case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
Expand Down
209 changes: 209 additions & 0 deletions tests/cross/src/test/scala/tests/InferExpectedTypeSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package tests

import java.nio.file.Paths

import scala.meta.internal.metals.CompilerOffsetParams
import scala.meta.internal.metals.EmptyCancelToken
import scala.meta.internal.mtags.MtagsEnrichments._

import munit.TestOptions

class InferExpectedTypeSuite extends BasePCSuite {
override protected def ignoreScalaVersion: Option[IgnoreScalaVersion] = Some(
IgnoreScala2
)

def check(
name: TestOptions,
original: String,
expectedType: String,
fileName: String = "A.scala"
): Unit = test(name) {
presentationCompiler.restart()
val (code, offset) = params(original.replace("@@", "CURSOR@@"), fileName)
val offsetParams = CompilerOffsetParams(
Paths.get(fileName).toUri(),
code,
offset,
EmptyCancelToken
)
presentationCompiler.inferExpectedType(offsetParams).get().asScala match {
case Some(value) => assertNoDiff(value, expectedType)
case None => fail("Empty result.")
}
}

check(
"type-ascription",
"""|def doo = (@@ : Double)
|""".stripMargin,
"""|Double
|""".stripMargin
)
// some structures

check(
"try",
"""|val _: Int =
| try {
| @@
| } catch {
| case _ =>
| }
|""".stripMargin,
"""|Int
|""".stripMargin
)

check(
"try-catch",
"""|val _: Int =
| try {
| } catch {
| case _ => @@
| }
|""".stripMargin,
"""|Int
|""".stripMargin
)

check(
"if-condition",
"""|val _ = if @@ then 1 else 2
|""".stripMargin,
"""|Boolean
|""".stripMargin
)

check(
"inline-if",
"""|inline def o: Int = inline if ??? then @@ else ???
|""".stripMargin,
"""|Int
|""".stripMargin
)

// pattern matching

check(
"pattern-match",
"""|val _ =
| List(1) match
| case @@
|""".stripMargin,
"""|List[Int]
|""".stripMargin
)

check(
"bind",
"""|val _ =
| List(1) match
| case name @ @@
|""".stripMargin,
"""|List[Int]
|""".stripMargin
)

check(
"alternative",
"""|val _ =
| List(1) match
| case Nil | @@
|""".stripMargin,
"""|List[Int]
|""".stripMargin
)

check(
"unapply".ignore,
"""|val _ =
| List(1) match
| case @@ :: _ =>
|""".stripMargin,
"""|Int
|""".stripMargin
)

// generic functions

check(
"any-generic",
"""|val _ : List[Int] = identity(@@)
|""".stripMargin,
"""|List[Int]
|""".stripMargin
)

check(
"eq-generic",
"""|def eq[T](a: T, b: T): Boolean = ???
|val _ = eq(1, @@)
|""".stripMargin,
"""|Int
|""".stripMargin
)

check(
"flatmap".ignore,
"""|val _ : List[Int] = List().flatMap(_ => @@)
|""".stripMargin,
"""|IterableOnce[Int]
|""".stripMargin
)

check(
"for-comprehension".ignore,
"""|val _ : List[Int] =
| for {
| _ <- List("a", "b")
| } yield @@
|""".stripMargin,
"""|Int
|""".stripMargin
)

// bounds
check(
"any".ignore,
"""|trait Foo
|def foo[T](a: T): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|<: Any
|""".stripMargin
)

check(
"bounds-1".ignore,
"""|trait Foo
|def foo[T <: Foo](a: Foo): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|<: Foo
|""".stripMargin
)

check(
"bounds-2".ignore,
"""|trait Foo
|def foo[T :> Foo](a: Foo): Boolean = ???
|val _ = foo(@@)
|""".stripMargin,
"""|:> Foo
|""".stripMargin
)

check(
"bounds-3".ignore,
"""|trait A
|class B extends A
|class C extends B
|def roo[F >: C <: A](f: F) = ???
|val kjk = roo(@@)
|""".stripMargin,
"""|>: C <: A
|""".stripMargin
)

}

0 comments on commit 067d1da

Please sign in to comment.