diff --git a/metals/src/main/scala/scala/meta/internal/bsp/BuildChange.scala b/metals/src/main/scala/scala/meta/internal/bsp/BuildChange.scala index 79c5cf949a6..72cde8f50b7 100644 --- a/metals/src/main/scala/scala/meta/internal/bsp/BuildChange.scala +++ b/metals/src/main/scala/scala/meta/internal/bsp/BuildChange.scala @@ -12,4 +12,5 @@ object BuildChange { case object Failed extends BuildChange case object Reconnected extends BuildChange case object Reloaded extends BuildChange + case object Cancelled extends BuildChange } diff --git a/metals/src/main/scala/scala/meta/internal/builds/BloopInstall.scala b/metals/src/main/scala/scala/meta/internal/builds/BloopInstall.scala index 44fed0a0699..9b4ba9bea41 100644 --- a/metals/src/main/scala/scala/meta/internal/builds/BloopInstall.scala +++ b/metals/src/main/scala/scala/meta/internal/builds/BloopInstall.scala @@ -1,7 +1,6 @@ package scala.meta.internal.builds import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.ExecutionContext import scala.concurrent.Future @@ -37,33 +36,23 @@ final class BloopInstall( override def toString: String = s"BloopInstall($workspace)" def runUnconditionally( - buildTool: BloopInstallProvider, - isImportInProcess: AtomicBoolean, + buildTool: BloopInstallProvider ): Future[WorkspaceLoadedStatus] = { - if (isImportInProcess.compareAndSet(false, true)) { - buildTool.bloopInstall( - workspace, - args => { - scribe.info(s"running '${args.mkString(" ")}'") - val process = - runArgumentsUnconditionally(buildTool, args, userConfig().javaHome) - process.foreach { e => - if (e.isFailed) { - // Record the exact command that failed to help troubleshooting. - scribe.error(s"$buildTool command failed: ${args.mkString(" ")}") - } + buildTool.bloopInstall( + workspace, + args => { + scribe.info(s"running '${args.mkString(" ")}'") + val process = + runArgumentsUnconditionally(buildTool, args, userConfig().javaHome) + process.foreach { e => + if (e.isFailed) { + // Record the exact command that failed to help troubleshooting. + scribe.error(s"$buildTool command failed: ${args.mkString(" ")}") } - process.onComplete(_ => isImportInProcess.set(false)) - process - }, - ) - } else { - Future - .successful { - languageClient.showMessage(ImportAlreadyRunning) - WorkspaceLoadedStatus.Dismissed } - } + process + }, + ) } private def runArgumentsUnconditionally( @@ -123,7 +112,6 @@ final class BloopInstall( def runIfApproved( buildTool: BloopInstallProvider, digest: String, - isImportInProcess: AtomicBoolean, ): Future[WorkspaceLoadedStatus] = synchronized { oldInstallResult(digest) match { @@ -133,7 +121,7 @@ final class BloopInstall( Future.successful(result) case _ => if (userConfig().shouldAutoImportNewProject) { - runUnconditionally(buildTool, isImportInProcess) + runUnconditionally(buildTool) } else { scribe.debug("Awaiting user response...") for { @@ -145,7 +133,7 @@ final class BloopInstall( ) installResult <- { if (userResponse.isYes) { - runUnconditionally(buildTool, isImportInProcess) + runUnconditionally(buildTool) } else { // Don't spam the user with requests during rapid build changes. notification.dismiss(2, TimeUnit.MINUTES) diff --git a/metals/src/main/scala/scala/meta/internal/metals/ConnectionProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/ConnectionProvider.scala index 91debfe58dd..1677f108a3b 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ConnectionProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ConnectionProvider.scala @@ -1,12 +1,14 @@ package scala.meta.internal.metals import java.nio.charset.Charset +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.ExecutionContextExecutorService import scala.concurrent.Future import scala.concurrent.Promise +import scala.util.Failure import scala.util.control.NonFatal import scala.meta.internal.bsp @@ -25,6 +27,7 @@ import scala.meta.internal.builds.Digest.Status import scala.meta.internal.builds.SbtBuildTool import scala.meta.internal.builds.ScalaCliBuildTool import scala.meta.internal.builds.ShellRunner +import scala.meta.internal.metals.Interruptable._ import scala.meta.internal.metals.Messages.IncompatibleBloopVersion import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.doctor.Doctor @@ -296,34 +299,103 @@ class ConnectionProvider( } object Connect { + class RequestInfo(val request: ConnectRequest) { + val promise: Promise[BuildChange] = Promise() + val cancelPromise: Promise[Unit] = Promise() + def cancel(): Boolean = cancelPromise.trySuccess(()) + } + + @volatile private var currentRequest: Option[RequestInfo] = None + private val queue = new ConcurrentLinkedQueue[RequestInfo]() + + def getOngoingRequest(): Option[RequestInfo] = currentRequest + def connect[T](request: ConnectRequest): Future[BuildChange] = { - request match { - case Disconnect(shutdownBuildServer) => disconnect(shutdownBuildServer) - case Index(check) => index(check) - case ImportBuildAndIndex(session) => importBuildAndIndex(session) - case ConnectToSession(session) => connectToSession(session) - case CreateSession(shutdownBuildServer) => - createSession(shutdownBuildServer) - case GenerateBspConfigAndConnect(buildTool, shutdownServer) => - generateBspConfigAndConnect(buildTool, shutdownServer) - case BloopInstallAndConnect( - buildTool, - checksum, - forceImport, - shutdownServer, - ) => - bloopInstallAndConnect( - buildTool, - checksum, - forceImport, - shutdownServer, - ) + val info = addToQueue(request) + pollAndConnect() + info.promise.future + } + + private def addToQueue(request: ConnectRequest): RequestInfo = + synchronized { + val info = new RequestInfo(request) + val iter = queue.iterator() + while (iter.hasNext()) { + val curr = iter.next() + request.cancelCompare(iter.next().request) match { + case 1 => curr.cancel() + case -1 => info.cancel() + case _ => + } + } + queue.add(info) + // maybe cancel ongoing + currentRequest.foreach(ongoing => + if (request.cancelCompare(ongoing.request) == 1) ongoing.cancel() + ) + info + } + + private def pollAndConnect(): Unit = { + val optRequest = synchronized { + if (currentRequest.isEmpty) { + currentRequest = Option(queue.poll()) + currentRequest + } else None + } + + for (request <- optRequest) { + val cancelPromise = request.cancelPromise + val result = + if (cancelPromise.isCompleted) + Interruptable.successful(BuildChange.Cancelled) + else + request.request match { + case Disconnect(shutdownBuildServer) => + disconnect(shutdownBuildServer, cancelPromise) + case Index(check) => index(check, cancelPromise) + case ImportBuildAndIndex(session) => + importBuildAndIndex(session, cancelPromise) + case ConnectToSession(session) => + connectToSession(session, cancelPromise) + case CreateSession(shutdownBuildServer) => + createSession(shutdownBuildServer, cancelPromise) + case GenerateBspConfigAndConnect(buildTool, shutdownServer) => + generateBspConfigAndConnect( + buildTool, + shutdownServer, + cancelPromise, + ) + case BloopInstallAndConnect( + buildTool, + checksum, + forceImport, + shutdownServer, + ) => + bloopInstallAndConnect( + buildTool, + checksum, + forceImport, + shutdownServer, + cancelPromise, + ) + } + result.future.onComplete { res => + res match { + case Failure(CancelConnectException) => + request.promise.trySuccess(BuildChange.Cancelled) + case _ => request.promise.tryComplete(res) + } + currentRequest = None + pollAndConnect() + } } } private def disconnect( - shutdownBuildServer: Boolean - ): Future[BuildChange] = { + shutdownBuildServer: Boolean, + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { def shutdownBsp(optMainBsp: Option[String]): Future[Boolean] = { optMainBsp match { case Some(BloopServers.name) => @@ -348,34 +420,43 @@ class ConnectionProvider( ) for { - _ <- scalaCli.stop() - optMainBsp <- bspSession match { + _ <- scalaCli.stop(storeLast = true).withInterrupt(cancelPromise) + optMainBsp <- (bspSession match { case None => Future.successful(None) case Some(session) => bspSession = None mainBuildTargetsData.resetConnections(List.empty) session.shutdown().map(_ => Some(session.main.name)) - } + }).withInterrupt(cancelPromise) _ <- - if (shutdownBuildServer) shutdownBsp(optMainBsp) - else Future.successful(()) + if (shutdownBuildServer) + shutdownBsp(optMainBsp).withInterrupt(cancelPromise) + else Interruptable.successful(()) } yield BuildChange.None } - private def index(check: () => Unit): Future[BuildChange] = - profiledIndexWorkspace(check).map(_ => BuildChange.None) + private def index( + check: () => Unit, + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = + profiledIndexWorkspace(check) + .map(_ => BuildChange.None) + .withInterrupt(cancelPromise) private def importBuildAndIndex( - session: BspSession - ): Future[BuildChange] = { + session: BspSession, + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { val importedBuilds0 = timerProvider.timed("Imported build") { session.importBuilds() } for { - bspBuilds <- workDoneProgress.trackFuture( - Messages.importingBuild, - importedBuilds0, - ) + bspBuilds <- workDoneProgress + .trackFuture( + Messages.importingBuild, + importedBuilds0, + ) + .withInterrupt(cancelPromise) _ = { val idToConnection = bspBuilds.flatMap { bspBuild => val targets = @@ -386,7 +467,7 @@ class ConnectionProvider( saveProjectReferencesInfo(bspBuilds) } _ = compilers.cancel() - buildChange <- index(check) + buildChange <- index(check, cancelPromise) } yield buildChange } @@ -408,7 +489,10 @@ class ConnectionProvider( DelegateSetting.writeProjectRef(folder, projectRefs) } - private def connectToSession(session: BspSession): Future[BuildChange] = { + private def connectToSession( + session: BspSession, + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { scribe.info( s"Connected to Build server: ${session.main.name} v${session.version}" ) @@ -419,7 +503,7 @@ class ConnectionProvider( bspSession = Some(session) isConnecting.set(false) for { - _ <- importBuildAndIndex(session) + _ <- importBuildAndIndex(session, cancelPromise) _ = buildToolProvider.buildTool.foreach( workspaceReload.persistChecksumStatus(Digest.Status.Installed, _) ) @@ -450,7 +534,10 @@ class ConnectionProvider( } } - def createSession(shutdownServer: Boolean): Future[BuildChange] = { + def createSession( + shutdownServer: Boolean, + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { def compileAllOpenFiles: BuildChange => Future[BuildChange] = { case change if !change.isFailed => Future @@ -465,25 +552,25 @@ class ConnectionProvider( case other => Future.successful(other) } - val scalaCliPaths = scalaCli.paths - isConnecting.set(true) (for { - _ <- disconnect(shutdownServer) - maybeSession <- timerProvider.timed( - "Connected to build server", - true, - ) { - bspConnector.connect( - buildToolProvider.buildTool, - folder, - userConfig, - shellRunner, - ) - } + _ <- disconnect(shutdownServer, cancelPromise) + maybeSession <- timerProvider + .timed( + "Connected to build server", + true, + ) { + bspConnector.connect( + buildToolProvider.buildTool, + folder, + userConfig, + shellRunner, + ) + } + .withInterrupt(cancelPromise) result <- maybeSession match { case Some(session) => - val result = connectToSession(session) + val result = connectToSession(session, cancelPromise) session.mainConnection.onReconnection { newMainConn => val updSession = session.copy(main = newMainConn) connect(ConnectToSession(updSession)) @@ -492,19 +579,17 @@ class ConnectionProvider( } result case None => - Future.successful(BuildChange.None) + Interruptable.successful(BuildChange.None) } - _ <- Future.sequence( - scalaCliPaths - .collect { - case path if (!buildTargets.belongsToBuildTarget(path.toNIO)) => - scalaCli.start(path) - } - ) + _ <- scalaCli + .startForAllLastPaths(path => + !buildTargets.belongsToBuildTarget(path.toNIO) + ) + .withInterrupt(cancelPromise) _ = initTreeView() } yield result) .recover { case NonFatal(e) => - disconnect(false) + disconnect(false, cancelPromise) val message = "Failed to connect with build server, no functionality will work." val details = " See logs for more details." @@ -514,7 +599,7 @@ class ConnectionProvider( scribe.error(message, e) BuildChange.Failed } - .flatMap(compileAllOpenFiles) + .flatMap(compileAllOpenFiles(_).withInterrupt(cancelPromise)) .map { res => buildServerPromise.trySuccess(()) res @@ -524,23 +609,25 @@ class ConnectionProvider( private def generateBspConfigAndConnect( buildTool: BuildServerProvider, shutdownServer: Boolean, - ): Future[BuildChange] = { + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { tables.buildTool.chooseBuildTool(buildTool.executableName) maybeChooseServer(buildTool.buildServerName, alreadySelected = false) for { _ <- - if (shutdownServer) disconnect(shutdownServer) - else Future.unit + if (shutdownServer) disconnect(shutdownServer, cancelPromise) + else Interruptable.successful(()) status <- buildTool .generateBspConfig( folder, args => bspConfigGenerator.runUnconditionally(buildTool, args), statusBar, ) + .withInterrupt(cancelPromise) shouldConnect = handleGenerationStatus(buildTool, status) status <- - if (shouldConnect) createSession(false) - else Future.successful(BuildChange.Failed) + if (shouldConnect) createSession(false, cancelPromise) + else Interruptable.successful(BuildChange.Failed) } yield status } @@ -575,30 +662,27 @@ class ConnectionProvider( false } - val isImportInProcess = new AtomicBoolean(false) - private def bloopInstallAndConnect( buildTool: BloopInstallProvider, checksum: String, forceImport: Boolean, shutdownServer: Boolean, - ): Future[BuildChange] = { + cancelPromise: Promise[Unit], + ): Interruptable[BuildChange] = { for { result <- { if (forceImport) bloopInstall.runUnconditionally( - buildTool, - isImportInProcess, + buildTool ) else bloopInstall.runIfApproved( buildTool, checksum, - isImportInProcess, ) - } + }.withInterrupt(cancelPromise) change <- { - if (result.isInstalled) createSession(shutdownServer) + if (result.isInstalled) createSession(shutdownServer, cancelPromise) else if (result.isFailed) { for { change <- @@ -614,13 +698,13 @@ class ConnectionProvider( // Connect nevertheless, many build import failures are caused // by resolution errors in one weird module while other modules // exported successfully. - createSession(shutdownServer) + createSession(shutdownServer, cancelPromise) } else { languageClient.showMessage(Messages.ImportProjectFailed) - Future.successful(BuildChange.Failed) + Interruptable.successful(BuildChange.Failed) } } yield change - } else Future.successful(BuildChange.None) + } else Interruptable.successful(BuildChange.None) } } yield change } @@ -630,21 +714,77 @@ class ConnectionProvider( sealed trait ConnectKind object SlowConnect extends ConnectKind -sealed trait ConnectRequest extends ConnectKind +sealed trait ConnectRequest extends ConnectKind { + + /** + * -1 cancel this + * 1 cancel other + * 0 queue + * @param other + * @return + */ + def cancelCompare(other: ConnectRequest): Int +} -case class Disconnect(shutdownBuildServer: Boolean) extends ConnectRequest -case class Index(check: () => Unit) extends ConnectRequest -case class ImportBuildAndIndex(bspSession: BspSession) extends ConnectRequest -case class ConnectToSession(bspSession: BspSession) extends ConnectRequest +case class Disconnect(shutdownBuildServer: Boolean) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case _: Index => 0 + case _ => -1 + } +} +case class Index(check: () => Unit) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case _: Disconnect => 0 + case _ => -1 + } +} +case class ImportBuildAndIndex(bspSession: BspSession) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case (_: Index) | (_: ImportBuildAndIndex) => 1 + case _: Disconnect => 0 + case _ => -1 + } +} +case class ConnectToSession(bspSession: BspSession) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case (_: Disconnect) | (_: Index) | (_: ConnectToSession) => 1 + case _ => -1 + } +} case class CreateSession(shutdownBuildServer: Boolean = false) - extends ConnectRequest + extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case (_: Disconnect) | (_: Index) | (_: ConnectToSession) | CreateSession( + false + ) => + 1 + case _ => -1 + } +} case class GenerateBspConfigAndConnect( buildTool: BuildServerProvider, shutdownServer: Boolean = false, -) extends ConnectRequest +) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case BloopInstallAndConnect(_, _, _, true) if !shutdownServer => 0 + case _ => 1 + } +} case class BloopInstallAndConnect( buildTool: BloopInstallProvider, checksum: String, forceImport: Boolean, shutdownServer: Boolean, -) extends ConnectRequest +) extends ConnectRequest { + def cancelCompare(other: ConnectRequest): Int = + other match { + case GenerateBspConfigAndConnect(_, true) if !shutdownServer => 0 + case _ => 1 + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/Interruptable.scala b/metals/src/main/scala/scala/meta/internal/metals/Interruptable.scala new file mode 100644 index 00000000000..2b3d7cc3f2c --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/Interruptable.scala @@ -0,0 +1,56 @@ +package scala.meta.internal.metals + +import java.util.concurrent.CompletableFuture + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.Promise + +import scala.meta.internal.metals.Interruptable.CancelConnectException + +class Interruptable[+T] private ( + futureIn: Future[T], + cancelPromise: Promise[Unit], +) extends CompletableFuture { + + def future(implicit executor: ExecutionContext): Future[T] = futureIn.map( + if (cancelPromise.isCompleted) throw CancelConnectException else _ + ) + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = { + cancelPromise.trySuccess(()) + true + } + + override def isCancelled(): Boolean = cancelPromise.isCompleted + + def flatMap[S]( + f: T => Interruptable[S] + )(implicit executor: ExecutionContext): Interruptable[S] = + new Interruptable(future.flatMap(f(_).future), cancelPromise) + + def map[S]( + f: T => S + )(implicit executor: ExecutionContext): Interruptable[S] = + new Interruptable(future.map(f(_)), cancelPromise) + + def recover[U >: T]( + pf: PartialFunction[Throwable, U] + )(implicit executor: ExecutionContext): Interruptable[U] = { + val pf0: PartialFunction[Throwable, U] = { case CancelConnectException => + throw CancelConnectException + } + new Interruptable(future.recover(pf0.orElse(pf)), cancelPromise) + } +} + +object Interruptable { + def successful[T](result: T) = + new Interruptable(Future.successful(result), Promise()) + + object CancelConnectException extends RuntimeException + implicit class XtensionFuture[+T](future: Future[T]) { + def withInterrupt(cancelPromise: Promise[Unit]): Interruptable[T] = + new Interruptable(future, cancelPromise) + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/scalacli/ScalaCliServers.scala b/metals/src/main/scala/scala/meta/internal/metals/scalacli/ScalaCliServers.scala index 9fd739544bd..268354f22a1 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/scalacli/ScalaCliServers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/scalacli/ScalaCliServers.scala @@ -49,6 +49,9 @@ class ScalaCliServers( )(implicit ec: ExecutionContextExecutorService) extends Cancelable { + private val lastServerPaths = + new AtomicReference[Set[AbsolutePath]](Set.empty) + private def localTmpWorkspace(path: AbsolutePath) = { val root = if (path.isDirectory) path else path.parent root.resolve(s".metals-scala-cli/") @@ -128,6 +131,11 @@ class ScalaCliServers( def paths: Iterable[AbsolutePath] = servers.map(_.path) + def startForAllLastPaths(filter: AbsolutePath => Boolean): Future[Set[Unit]] = + Future.sequence( + lastServerPaths.getAndSet(Set.empty).withFilter(filter).map(start) + ) + def start(path: AbsolutePath): Future[Unit] = { val customWorkspace = if (path.isDirectory) None @@ -201,8 +209,11 @@ class ScalaCliServers( } yield () } - def stop(): Future[Unit] = { + def stop(storeLast: Boolean = false): Future[Unit] = { val servers = serversRef.getAndSet(Queue.empty) + if (storeLast) { + lastServerPaths.updateAndGet(_ ++ servers.map(_.path)) + } Future.sequence(servers.map(_.stop()).toSeq).ignoreValue } diff --git a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala index 47e40e1ee0d..68ae1d14fc0 100644 --- a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala +++ b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala @@ -19,6 +19,7 @@ import scala.meta.io.AbsolutePath import com.google.gson.JsonObject import com.google.gson.JsonPrimitive +import org.eclipse.lsp4j.MessageActionItem import tests.BaseImportSuite import tests.JavaHomeChangeTest import tests.ScriptsAssertions @@ -211,6 +212,7 @@ class SbtBloopLspSuite _ = assertNoDiff( client.beginProgressMessages, List( + progressMessage, progressMessage, Messages.importingBuild, Messages.indexing, @@ -882,4 +884,38 @@ class SbtBloopLspSuite |""".stripMargin, ) + test("switch-build-server-while-connect") { + cleanWorkspace() + val layout = + s"""|/project/build.properties + |sbt.version=${V.sbtVersion} + |/build.sbt + |scalaVersion := "${V.scala213}" + |/src/main/scala/A.scala + | + |object A { + | val i: Int = "aaa" + |} + |""".stripMargin + writeLayout(layout) + client.importBuild = ImportBuild.yes + client.selectBspServer = { _ => new MessageActionItem("sbt") } + for { + _ <- server.initialize() + _ = server.initialized() + connectionProvider = server.headServer.connectionProvider.Connect + _ = while (connectionProvider.getOngoingRequest().isEmpty) { + // wait for connect to start + Thread.sleep(100) + } + bloopConnectF = connectionProvider.getOngoingRequest().get.promise.future + bspSwitchF = server.executeCommand(ServerCommands.BspSwitch) + _ <- bloopConnectF + _ = assert(!server.server.indexingPromise.isCompleted) + _ <- bspSwitchF + _ = assert(server.server.indexingPromise.isCompleted) + _ = assert(server.server.bspSession.exists(_.main.isSbt)) + } yield () + } + }