diff --git a/metals/src/main/scala/scala/meta/internal/metals/BatchedFunction.scala b/metals/src/main/scala/scala/meta/internal/metals/BatchedFunction.scala index 0c6773ae679..3070cccbf25 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/BatchedFunction.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/BatchedFunction.scala @@ -1,12 +1,15 @@ package scala.meta.internal.metals +import java.util.concurrent.CancellationException import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Success +import scala.util.Try import scala.util.control.NonFatal import scala.meta.internal.async.ConcurrentQueue @@ -22,6 +25,7 @@ final class BatchedFunction[A, B]( fn: Seq[A] => CancelableFuture[B], functionId: String, shouldLogQueue: Boolean = false, + default: Option[B] = None, )(implicit ec: ExecutionContext) extends (Seq[A] => Future[B]) with Function2[Seq[A], () => Unit, Future[B]] @@ -75,8 +79,17 @@ final class BatchedFunction[A, B]( } def cancelAll(): Unit = { - queue.clear() - unlock() + val requests = ConcurrentQueue.pollAll(queue) + requests.foreach(_.result.complete(defaultResult)) + cancelCurrent() + } + + def cancelCurrent(): Unit = { + lock.get() match { + case None => + case Some(promise) => + promise.tryFailure(new BatchedFunction.BatchedFunctionCancelation) + } } def currentFuture(): Future[B] = { @@ -97,22 +110,28 @@ final class BatchedFunction[A, B]( callback: () => Unit, ) - private val lock = new AtomicBoolean() + private val lock = new AtomicReference[Option[Promise[B]]](None) + private def unlock(): Unit = { - lock.set(false) + lock.set(None) if (!queue.isEmpty) { runAcquire() } } private def runAcquire(): Unit = { - if (!isPaused.get() && lock.compareAndSet(false, true)) { - runRelease() + lazy val promise = { + val p = Promise[B] + p.future.onComplete { _ => unlock() } + p + } + if (!isPaused.get() && lock.compareAndSet(None, Some(promise))) { + runRelease(promise) } else { // Do nothing, the submitted arguments will be handled // by a separate request. } } - private def runRelease(): Unit = { + private def runRelease(p: Promise[B]): Unit = { // Pre-condition: lock is acquired. // Pos-condition: // - lock is released @@ -128,24 +147,29 @@ final class BatchedFunction[A, B]( this.current.set(result) val resultF = for { result <- result.future - _ <- Future { - callbacks.foreach(cb => cb()) - } + _ <- Future { callbacks.foreach(cb => cb()) } } yield result - resultF.onComplete { response => - unlock() - requests.foreach(_.result.complete(response)) + resultF.onComplete(p.tryComplete) + p.future.onComplete { + case Failure(_: BatchedFunction.BatchedFunctionCancelation) => + result.cancel() + requests.foreach(_.result.complete(defaultResult)) + case result => + requests.foreach(_.result.complete(result)) } } else { - unlock() + p.tryFailure(new BatchedFunction.BatchedFunctionCancelation) } } catch { case NonFatal(e) => unlock() - requests.foreach(_.result.failure(e)) + requests.foreach(_.result.tryFailure(e)) scribe.error(s"Unexpected error releasing buffered job", e) } } + + def defaultResult: Try[B] = + default.map(Success(_)).getOrElse(Failure(new CancellationException)) } object BatchedFunction { @@ -153,6 +177,7 @@ object BatchedFunction { fn: Seq[A] => Future[B], functionId: String, shouldLogQueue: Boolean = false, + default: Option[B] = None, )(implicit ec: ExecutionContext ): BatchedFunction[A, B] = @@ -160,5 +185,7 @@ object BatchedFunction { fn.andThen(CancelableFuture(_)), functionId, shouldLogQueue, + default, ) + class BatchedFunctionCancelation extends RuntimeException } diff --git a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala index 5bb0efb3263..62ee1e5736c 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala @@ -19,6 +19,7 @@ import scala.concurrent.ExecutionContextExecutorService import scala.concurrent.Future import scala.concurrent.Promise import scala.reflect.ClassTag +import scala.util.Success import scala.util.Try import scala.meta.internal.builds.MillBuildTool @@ -74,7 +75,6 @@ class BuildServerConnection private ( private val ongoingRequests = new MutableCancelable().addAll(initialConnection.cancelables) - private val ongoingCompilations = new MutableCancelable() def version: String = _version.get() @@ -155,7 +155,6 @@ class BuildServerConnection private ( def compile(params: CompileParams): CompletableFuture[CompileResult] = { register( server => server.buildTargetCompile(params), - isCompile = true, onFail = Some( ( new CompileResult(StatusCode.CANCELLED), @@ -300,14 +299,9 @@ class BuildServerConnection private ( override def cancel(): Unit = { if (cancelled.compareAndSet(false, true)) { ongoingRequests.cancel() - ongoingCompilations.cancel() } } - def cancelCompilations(): Unit = { - ongoingCompilations.cancel() - } - private def askUser( original: Future[BuildServerConnection.LauncherConnection] ): Future[BuildServerConnection.LauncherConnection] = { @@ -357,9 +351,8 @@ class BuildServerConnection private ( private def register[T: ClassTag]( action: MetalsBuildServer => CompletableFuture[T], onFail: => Option[(T, String)] = None, - isCompile: Boolean = false, ): CompletableFuture[T] = { - + val localCancelable = new MutableCancelable() def runWithCanceling( launcherConnection: BuildServerConnection.LauncherConnection ): Future[T] = { @@ -367,14 +360,14 @@ class BuildServerConnection private ( val cancelable = Cancelable { () => Try(resultFuture.cancel(true)) } - if (isCompile) ongoingCompilations.add(cancelable) - else ongoingRequests.add(cancelable) + ongoingRequests.add(cancelable) + localCancelable.add(cancelable) val result = resultFuture.asScala result.onComplete { _ => - if (isCompile) ongoingCompilations.remove(cancelable) - else ongoingRequests.remove(cancelable) + ongoingRequests.remove(cancelable) + localCancelable.remove(cancelable) } result } @@ -411,7 +404,14 @@ class BuildServerConnection private ( Future.failed(new MetalsBspException(name, t)) }) } - CancelTokens.future(_ => actionFuture) + + CancelTokens.future { token => + token.onCancel().asScala.onComplete { + case Success(java.lang.Boolean.TRUE) => localCancelable.cancel() + case _ => + } + actionFuture + } } def isBuildServerResponsive: Future[Option[Boolean]] = { diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala index 60e0762ec00..07f7ba8cdf4 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala @@ -32,12 +32,12 @@ final class Compilations( new BatchedFunction[ b.BuildTargetIdentifier, Map[BuildTargetIdentifier, b.CompileResult], - ](compile, "compileBatch", shouldLogQueue = true) + ](compile, "compileBatch", shouldLogQueue = true, Some(Map.empty)) private val cascadeBatch = new BatchedFunction[ b.BuildTargetIdentifier, Map[BuildTargetIdentifier, b.CompileResult], - ](compile, "cascadeBatch", shouldLogQueue = true) + ](compile, "cascadeBatch", shouldLogQueue = true, Some(Map.empty)) def pauseables: List[Pauseable] = List(compileBatch, cascadeBatch) private val isCompiling = TrieMap.empty[b.BuildTargetIdentifier, Boolean] @@ -115,15 +115,6 @@ final class Compilations( def cancel(): Unit = { cascadeBatch.cancelAll() compileBatch.cancelAll() - buildTargets.all - .flatMap { target => - buildTargets.buildServerOf(target.getId()) - } - .distinct - .foreach { conn => - conn.cancelCompilations() - } - } def recompileAll(): Future[Unit] = { diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index 500ace26cfb..dc28139e724 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -2202,6 +2202,8 @@ class MetalsLspService( } def disconnectOldBuildServer(): Future[Unit] = { + compilations.cancel() + buildTargetClasses.cancel() diagnostics.reset() bspSession.foreach(connection => scribe.info(s"Disconnecting from ${connection.main.name} session...") diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala index a116fbafccf..c6bd1d79d91 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala @@ -24,7 +24,11 @@ final class BuildTargetClasses( : TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] = TrieMap.empty[b.BuildTargetIdentifier, b.JvmEnvironmentItem] val rebuildIndex: BatchedFunction[b.BuildTargetIdentifier, Unit] = - BatchedFunction.fromFuture(fetchClasses, "buildTargetClasses") + BatchedFunction.fromFuture( + fetchClasses, + "buildTargetClasses", + default = Some(()), + ) def classesOf(target: b.BuildTargetIdentifier): Classes = { index.getOrElse(target, new Classes) @@ -175,6 +179,10 @@ final class BuildTargetClasses( val name = NameTransformer.decode(names.last) descriptors.map(descriptor => Symbols.Global(prefix, descriptor(name))) } + + def cancel(): Unit = { + rebuildIndex.cancelAll() + } } sealed abstract class TestFramework(val canResolveChildren: Boolean) diff --git a/tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala b/tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala index 23f8e71fe29..e8917d435d2 100644 --- a/tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala +++ b/tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala @@ -1,10 +1,14 @@ package tests +import java.util.concurrent.Executors + import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.util.Success import scala.meta.internal.metals.BatchedFunction +import scala.meta.internal.metals.Cancelable +import scala.meta.internal.metals.CancelableFuture class BatchedFunctionSuite extends BaseSuite { test("batch") { @@ -95,11 +99,51 @@ class BatchedFunctionSuite extends BaseSuite { mkString.unpause() - assertDiffEqual(paused.value, None) - assertDiffEqual(paused2.value, None) - - val unpaused2 = mkString(List("a", "b")) - assertDiffEqual(unpaused2.value, Some(Success("ab"))) + for { + _ <- paused.failed + _ <- paused2.failed + res <- mkString(List("a", "b")) + _ = assertEquals(res, "ab") + } yield () + } + test("cancel2") { + val executorService = Executors.newFixedThreadPool(10) + val ec2 = ExecutionContext.fromExecutor(executorService) + var i = 1 + val stuckExample: BatchedFunction[String, String] = + new BatchedFunction( + (seq: Seq[String]) => { + seq.toList match { + case "loop" :: Nil => + val future = Future.apply { + while (i == 1) { + Thread.sleep(1) + } + "loop-result" + }(ec2) + CancelableFuture[String](future, Cancelable { () => i = 2 }) + case _ => + CancelableFuture[String]( + Future.successful("result"), + Cancelable.empty, + ) + } + }, + "stuck example", + default = Some("default"), + )(ec2) + val cancelled = stuckExample("loop") + assertEquals(i, 1) + assert(cancelled.value.isEmpty) + val normal = stuckExample("normal") + stuckExample.cancelCurrent() + for { + str <- cancelled + _ = assertEquals(i, 2) + _ = assertEquals(str, "default") + str <- normal + _ = assertEquals(str, "result") + } yield () } } diff --git a/tests/unit/src/test/scala/tests/BillLspSuite.scala b/tests/unit/src/test/scala/tests/BillLspSuite.scala index 5a50c4ad3ec..f3743ddbea1 100644 --- a/tests/unit/src/test/scala/tests/BillLspSuite.scala +++ b/tests/unit/src/test/scala/tests/BillLspSuite.scala @@ -10,6 +10,7 @@ import scala.meta.internal.metals.ServerCommands import scala.meta.io.AbsolutePath import bill._ +import ch.epfl.scala.bsp4j.StatusCode class BillLspSuite extends BaseLspSuite("bill") { @@ -222,4 +223,44 @@ class BillLspSuite extends BaseLspSuite("bill") { Bill.installGlobal(globalBsp.toNIO, "Bob") testSelectServerDialogue() } + + test("cancel-compile") { + val cancelPattern = + """Sending notification '\$\/cancelRequest'\s*Params: \{\s*\"id\": \"([0-9]+)\"\s*\}""".r + cleanWorkspace() + Bill.installWorkspace(workspace.toNIO, "Bill") + def trace = workspace.resolve(".metals/bsp.trace.json").readText + for { + _ <- initialize( + """|/src/com/App.scala + |object App { + | val x: Int = 1 + |} + |/.metals/bsp.trace.json + |""".stripMargin + ) + (compileReport, _) <- server.server.compilations + .compileFile( + workspace.resolve("src/com/App.scala") + ) + .zip { + // wait until the compilation start + while (!trace.contains(s"buildTarget/compile")) { + Thread.sleep(1) + } + server.executeCommand(ServerCommands.CancelCompile) + } + _ = assertEquals(compileReport.getStatusCode(), StatusCode.CANCELLED) + currentTrace = trace + cancelMatch = cancelPattern.findFirstMatchIn(currentTrace) + _ = assert(cancelMatch.nonEmpty, trace) + cancelId = cancelMatch.get.group(1) + _ = assert(currentTrace.contains(s"buildTarget/compile - ($cancelId)")) + compileReport <- server.server.compilations + .compileFile( + workspace.resolve("src/com/App.scala") + ) + _ = assertEquals(compileReport.getStatusCode(), StatusCode.OK) + } yield () + } }