diff --git a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala index 8c9c07ae38c..e5325964b49 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala @@ -61,6 +61,8 @@ class BuildServerConnection private ( extends Cancelable { private val defaultMinTimeout = FiniteDuration(3, TimeUnit.MINUTES) + private def defaultTimeout(name: String) = + Some(Timeout.default(name, defaultMinTimeout)) @volatile private var connection = Future.successful(initialConnection) initialConnection.setReconnect(() => reconnect().ignoreValue) @@ -73,7 +75,6 @@ class BuildServerConnection private ( val requestRegistry = new RequestRegistry( - defaultMinTimeout, initialConnection.cancelables, languageClient, Some(requestTimeOutNotification), @@ -182,7 +183,7 @@ class BuildServerConnection private ( def compile( params: CompileParams, - timeout: Timeout, + timeout: Option[Timeout], ): CompletableFuture[CompileResult] = { register( server => server.buildTargetCompile(params), @@ -225,6 +226,7 @@ class BuildServerConnection private ( register( server => server.buildTargetScalaMainClasses(params), onFail, + defaultTimeout("main classes"), ).asScala } else Future.successful(resultOnUnsupported) @@ -244,6 +246,7 @@ class BuildServerConnection private ( register( server => server.buildTargetScalaTestClasses(params), onFail, + defaultTimeout("test classes"), ).asScala } else Future.successful(resultOnUnsupported) } @@ -429,7 +432,7 @@ class BuildServerConnection private ( private def register[T: ClassTag]( action: MetalsBuildServer => CompletableFuture[T], onFail: => Option[(T, String)] = None, - timeout: Timeout = Timeout.NoTimeout, + timeout: Option[Timeout] = None, ): CompletableFuture[T] = { val localCancelable = new MutableCancelable() def runWithCanceling( diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala index b8647eada15..49f7e2fdcd1 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilations.scala @@ -30,17 +30,15 @@ final class Compilations( compileWorksheets: Seq[AbsolutePath] => Future[Unit], onStartCompilation: () => Unit, )(implicit ec: ExecutionContext) { - - private val cascadeCompileTimeout = Timeout.NoTimeout - private val compileTimeout: Timeout.FlexTimeout = - Timeout.FlexTimeout("compile", Duration(10, TimeUnit.MINUTES)) + private val compileTimeout: Timeout = + Timeout("compile", Duration(10, TimeUnit.MINUTES)) // we are maintaining a separate queue for cascade compilation since those must happen ASAP private val compileBatch = new BatchedFunction[ b.BuildTargetIdentifier, Map[BuildTargetIdentifier, b.CompileResult], ]( - compile(compileTimeout), + compile(timeout = Some(compileTimeout)), "compileBatch", shouldLogQueue = true, Some(Map.empty), @@ -50,7 +48,7 @@ final class Compilations( b.BuildTargetIdentifier, Map[BuildTargetIdentifier, b.CompileResult], ]( - compile(cascadeCompileTimeout), + compile(timeout = None), "cascadeBatch", shouldLogQueue = true, Some(Map.empty), @@ -166,7 +164,7 @@ final class Compilations( for { cleanResult <- cleaned if cleanResult.getCleaned() == true - _ <- compile(cascadeCompileTimeout)(targetIds).future + _ <- compile(timeout = None)(targetIds).future } yield () } @@ -204,7 +202,7 @@ final class Compilations( Future.sequence(expansions).map(_.flatten) } - private def compile(timeout: Timeout)( + private def compile(timeout: Option[Timeout])( targets: Seq[b.BuildTargetIdentifier] ): CancelableFuture[Map[BuildTargetIdentifier, b.CompileResult]] = { @@ -243,7 +241,7 @@ final class Compilations( private def compile( connection: BuildServerConnection, targets: Seq[b.BuildTargetIdentifier], - timeout: Timeout, + timeout: Option[Timeout], ): CancelableFuture[b.CompileResult] = { val params = new b.CompileParams(targets.asJava) targets.foreach(target => isCompiling(target) = true) diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala index db0306eccdf..31891f03f3c 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala @@ -233,7 +233,9 @@ object MetalsEnrichments ec: ExecutionContext ): Future[A] = withTimeout(FiniteDuration(length, unit)) - def withTimeout(duration: FiniteDuration)(implicit ec: ExecutionContext): Future[A] = { + def withTimeout( + duration: FiniteDuration + )(implicit ec: ExecutionContext): Future[A] = { Future(Await.result(future, duration)) } diff --git a/metals/src/main/scala/scala/meta/internal/metals/utils/RequestRegistry.scala b/metals/src/main/scala/scala/meta/internal/metals/utils/RequestRegistry.scala index 65fb31507f2..32b1b80a13b 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/utils/RequestRegistry.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/utils/RequestRegistry.scala @@ -7,7 +7,6 @@ import java.util.concurrent.TimeoutException import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.duration.Duration -import scala.concurrent.duration.FiniteDuration import scala.util.Failure import scala.util.Success import scala.util.Try @@ -22,7 +21,6 @@ import scala.meta.internal.metals.MutableCancelable import org.eclipse.lsp4j.services.LanguageClient class RequestRegistry( - defaultMinTimeout: FiniteDuration, initialCancellables: List[Cancelable], languageClient: LanguageClient, requestTimeOutNotification: Option[DismissedNotifications#Notification] = @@ -30,7 +28,7 @@ class RequestRegistry( )(implicit ex: ExecutionContext ) { - private val timeouts: Timeouts = new Timeouts(defaultMinTimeout) + private val timeouts: Timeouts = new Timeouts() private val ongoingRequests = new MutableCancelable().addAll(initialCancellables) @@ -54,13 +52,14 @@ class RequestRegistry( def register[T]( action: () => CompletableFuture[T], - timeout: Timeout, + timeout: Option[Timeout], ): CancelableFuture[T] = { val CancelableFuture(result, cancelable) = - timeouts.getNameAndTimeout(timeout) match { - case Some((actionName, timeoutValue)) + timeout match { + case Some(timeout) if !requestTimeOutNotification.exists(_.isDismissed) => - FutureWithTimeout(timeoutValue, onTimeout(actionName)(_))(action) + val timeoutValue = timeouts.getTimeout(timeout) + FutureWithTimeout(timeoutValue, onTimeout(timeout.name)(_))(action) .transform { case Success((res, time)) => timeouts.measured(timeout, time) @@ -92,7 +91,6 @@ class RequestRegistry( ongoingRequests.cancel() } - def getTimeout(timeout: Timeout): Option[Duration] = - timeouts.getTimeout(timeout) + def getTimeout(timeout: Timeout): Duration = timeouts.getTimeout(timeout) } diff --git a/metals/src/main/scala/scala/meta/internal/metals/utils/Timeouts.scala b/metals/src/main/scala/scala/meta/internal/metals/utils/Timeouts.scala index a493f474661..8d037c222a5 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/utils/Timeouts.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/utils/Timeouts.scala @@ -5,53 +5,34 @@ import java.util.concurrent.atomic.AtomicReference import scala.concurrent.duration.FiniteDuration -sealed trait Timeout +case class Timeout(id: String, name: String, minTimeout: FiniteDuration) object Timeout { - case object NoTimeout extends Timeout - case class DefaultFlexTimeout(id: String) extends Timeout - case class FlexTimeout(id: String, minTimeout: FiniteDuration) extends Timeout + def apply(name: String, minTimeout: FiniteDuration): Timeout = + Timeout(name, name, minTimeout) + def default(name: String, minTimeout: FiniteDuration): Timeout = + Timeout("default", name, minTimeout) } -class Timeouts(defaultMinTimeout: FiniteDuration) { - private val defaultFlexTimeout: AtomicReference[Option[AvgTime]] = - new AtomicReference(None) +class Timeouts() { private val timeouts: AtomicReference[Map[String, AvgTime]] = new AtomicReference(Map()) + def measured(timeout: Timeout, time: FiniteDuration): Any = { val addToOption: Option[AvgTime] => Option[AvgTime] = { case Some(avgTime) => Some(avgTime.add(time)) case None => Some(AvgTime.of(time)) } - timeout match { - case Timeout.DefaultFlexTimeout(_) => - defaultFlexTimeout.getAndUpdate(addToOption(_)) - case Timeout.FlexTimeout(id, _) => - timeouts.getAndUpdate(_.updatedWith(id)(addToOption)) - case _ => - } + timeouts.getAndUpdate(_.updatedWith(timeout.id)(addToOption)) } - def getNameAndTimeout(timeout: Timeout): Option[(String, FiniteDuration)] = { - timeout match { - case Timeout.DefaultFlexTimeout(id) => - Some( - defaultFlexTimeout.get - .map(_.avgWithMin(defaultMinTimeout)) - .getOrElse(defaultMinTimeout) - ).map((id, _)) - case Timeout.FlexTimeout(id, minTimeout) => - Some( - timeouts.get - .get(id) - .map(_.avgWithMin(minTimeout)) - .getOrElse(minTimeout) - ).map((id, _)) - case Timeout.NoTimeout => None - } + def getTimeout(timeout: Timeout): FiniteDuration = { + val Timeout(id, _, minTimeout) = timeout + timeouts.get + .get(id) + .map(_.avgWithMin(minTimeout)) + .getOrElse(minTimeout) } - def getTimeout(timeout: Timeout): Option[FiniteDuration] = - getNameAndTimeout(timeout).map(_._2) } case class AvgTime(samples: Int, totalTime: Long) { diff --git a/tests/unit/src/test/scala/tests/RequestRegistrySuite.scala b/tests/unit/src/test/scala/tests/RequestRegistrySuite.scala index 336b80a7188..7b947330eba 100644 --- a/tests/unit/src/test/scala/tests/RequestRegistrySuite.scala +++ b/tests/unit/src/test/scala/tests/RequestRegistrySuite.scala @@ -26,14 +26,15 @@ class RequestRegistrySuite extends FunSuite { implicit val ex: ExecutionContextExecutorService = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool()) val duration: FiniteDuration = FiniteDuration(1, TimeUnit.SECONDS) - def createRegistry() = - new RequestRegistry(duration, List(), CancellingLanguageClient) - val defaultTimeout: Timeout.DefaultFlexTimeout = - Timeout.DefaultFlexTimeout("request") + def createRegistry( + onTimeout: () => MessageActionItem = () => Messages.RequestTimeout.cancel + ) = + new RequestRegistry(List(), new RequestRegistrySuite.Client(onTimeout)) + val defaultTimeout: Some[Timeout] = Some(Timeout.default("request", duration)) test("avg") { val requestRegistry = createRegistry() - val doneTimeout = Timeout.FlexTimeout("done", duration) + val doneTimeout = Some(Timeout("done", duration)) for { _ <- requestRegistry .register( @@ -57,14 +58,13 @@ class RequestRegistrySuite extends FunSuite { _ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future _ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future _ = assert( - requestRegistry.getTimeout(defaultTimeout).get > duration + requestRegistry.getTimeout(defaultTimeout.get) > duration ) _ = assert( requestRegistry - .getTimeout(defaultTimeout) - .get < duration * 2 + .getTimeout(defaultTimeout.get) < duration * 2 ) - _ = assert(requestRegistry.getTimeout(doneTimeout).get == duration) + _ = assert(requestRegistry.getTimeout(doneTimeout.get) == duration) } yield () } @@ -78,7 +78,7 @@ class RequestRegistrySuite extends FunSuite { .failed _ = assert(err.isInstanceOf[TimeoutException]) _ = assertEquals( - requestRegistry.getTimeout(defaultTimeout).get, + requestRegistry.getTimeout(defaultTimeout.get), duration * 3, ) _ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future @@ -92,11 +92,11 @@ class RequestRegistrySuite extends FunSuite { val requestRegistry = createRegistry() val f1 = requestRegistry.register( ExampleFutures.infinite(promise1), - Timeout.NoTimeout + timeout = None, ) val f2 = requestRegistry.register( ExampleFutures.infinite(promise2), - Timeout.NoTimeout + timeout = None, ) requestRegistry.cancel() for { @@ -109,15 +109,36 @@ class RequestRegistrySuite extends FunSuite { } yield () } -} + test("wait-on-timeout") { + val promise1 = Promise[Unit]() + var timeoutCount = 0 + val requestRegistry = createRegistry(() => { + timeoutCount += 1 + if (timeoutCount <= 1) Messages.RequestTimeout.waitAction + else Messages.RequestTimeout.cancel + }) + + for { + err <- requestRegistry + .register( + ExampleFutures.infinite(promise1), + defaultTimeout, + ) + .future + .failed + _ = assert(err.isInstanceOf[TimeoutException]) + _ = assert(timeoutCount == 2) + _ <- promise1.future + } yield () + } -object RequestType extends Enumeration { - val Type1, Type2 = Value } -object CancellingLanguageClient extends NoopLanguageClient { - override def showMessageRequest( - requestParams: ShowMessageRequestParams - ): CompletableFuture[MessageActionItem] = - Future.successful(Messages.RequestTimeout.cancel).asJava +object RequestRegistrySuite { + class Client(onTimeout: () => MessageActionItem) extends NoopLanguageClient { + override def showMessageRequest( + requestParams: ShowMessageRequestParams + ): CompletableFuture[MessageActionItem] = + Future.successful(onTimeout()).asJava + } } diff --git a/tests/unit/src/test/scala/tests/TimeoutSuite.scala b/tests/unit/src/test/scala/tests/TimeoutSuite.scala index 0deb4f08aaf..a9c50a509a0 100644 --- a/tests/unit/src/test/scala/tests/TimeoutSuite.scala +++ b/tests/unit/src/test/scala/tests/TimeoutSuite.scala @@ -55,25 +55,25 @@ class TimeoutSuite extends FunSuite { test("timeouts") { def min(int: Int) = Duration(int, TimeUnit.MINUTES) - val timeouts = new Timeouts(min(3)) - val flexTimeout = Timeout.FlexTimeout("flex", min(6)) - assertEquals(timeouts.getTimeout(Timeout.NoTimeout), None) + val defaultTime = Timeout.default("request", min(3)) + val timeouts = new Timeouts() + val flexTimeout = Timeout("flex", min(6)) assertEquals( - timeouts.getTimeout(Timeout.DefaultFlexTimeout("request")), - Some(min(3)), + timeouts.getTimeout(defaultTime), + min(3), ) - assertEquals(timeouts.getTimeout(flexTimeout), Some(min(6))) - timeouts.measured(Timeout.DefaultFlexTimeout("request"), min(1)) - timeouts.measured(Timeout.DefaultFlexTimeout("request"), min(3)) + assertEquals(timeouts.getTimeout(flexTimeout), min(6)) + timeouts.measured(defaultTime, min(1)) + timeouts.measured(defaultTime, min(3)) timeouts.measured(flexTimeout, min(1)) timeouts.measured(flexTimeout, min(2)) // avg * 3 > min assertEquals( - timeouts.getTimeout(Timeout.DefaultFlexTimeout("request")), - Some(min(6)), + timeouts.getTimeout(defaultTime), + min(6), ) // avg * 3 < min - assertEquals(timeouts.getTimeout(flexTimeout), Some(min(6))) + assertEquals(timeouts.getTimeout(flexTimeout), min(6)) } }