Skip to content

Commit

Permalink
add test + simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Nov 28, 2023
1 parent 071c3aa commit 7c0251d
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class BuildServerConnection private (
extends Cancelable {

private val defaultMinTimeout = FiniteDuration(3, TimeUnit.MINUTES)
private def defaultTimeout(name: String) =
Some(Timeout.default(name, defaultMinTimeout))

@volatile private var connection = Future.successful(initialConnection)
initialConnection.setReconnect(() => reconnect().ignoreValue)
Expand All @@ -73,7 +75,6 @@ class BuildServerConnection private (

val requestRegistry =
new RequestRegistry(
defaultMinTimeout,
initialConnection.cancelables,
languageClient,
Some(requestTimeOutNotification),
Expand Down Expand Up @@ -182,7 +183,7 @@ class BuildServerConnection private (

def compile(
params: CompileParams,
timeout: Timeout,
timeout: Option[Timeout],
): CompletableFuture[CompileResult] = {
register(
server => server.buildTargetCompile(params),
Expand Down Expand Up @@ -225,6 +226,7 @@ class BuildServerConnection private (
register(
server => server.buildTargetScalaMainClasses(params),
onFail,
defaultTimeout("main classes"),
).asScala
} else Future.successful(resultOnUnsupported)

Expand All @@ -244,6 +246,7 @@ class BuildServerConnection private (
register(
server => server.buildTargetScalaTestClasses(params),
onFail,
defaultTimeout("test classes"),
).asScala
} else Future.successful(resultOnUnsupported)
}
Expand Down Expand Up @@ -429,7 +432,7 @@ class BuildServerConnection private (
private def register[T: ClassTag](
action: MetalsBuildServer => CompletableFuture[T],
onFail: => Option[(T, String)] = None,
timeout: Timeout = Timeout.NoTimeout,
timeout: Option[Timeout] = None,
): CompletableFuture[T] = {
val localCancelable = new MutableCancelable()
def runWithCanceling(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,15 @@ final class Compilations(
compileWorksheets: Seq[AbsolutePath] => Future[Unit],
onStartCompilation: () => Unit,
)(implicit ec: ExecutionContext) {

private val cascadeCompileTimeout = Timeout.NoTimeout
private val compileTimeout: Timeout.FlexTimeout =
Timeout.FlexTimeout("compile", Duration(10, TimeUnit.MINUTES))
private val compileTimeout: Timeout =
Timeout("compile", Duration(10, TimeUnit.MINUTES))
// we are maintaining a separate queue for cascade compilation since those must happen ASAP
private val compileBatch =
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](
compile(compileTimeout),
compile(timeout = Some(compileTimeout)),
"compileBatch",
shouldLogQueue = true,
Some(Map.empty),
Expand All @@ -50,7 +48,7 @@ final class Compilations(
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](
compile(cascadeCompileTimeout),
compile(timeout = None),
"cascadeBatch",
shouldLogQueue = true,
Some(Map.empty),
Expand Down Expand Up @@ -166,7 +164,7 @@ final class Compilations(
for {
cleanResult <- cleaned
if cleanResult.getCleaned() == true
_ <- compile(cascadeCompileTimeout)(targetIds).future
_ <- compile(timeout = None)(targetIds).future
} yield ()
}

Expand Down Expand Up @@ -204,7 +202,7 @@ final class Compilations(
Future.sequence(expansions).map(_.flatten)
}

private def compile(timeout: Timeout)(
private def compile(timeout: Option[Timeout])(
targets: Seq[b.BuildTargetIdentifier]
): CancelableFuture[Map[BuildTargetIdentifier, b.CompileResult]] = {

Expand Down Expand Up @@ -243,7 +241,7 @@ final class Compilations(
private def compile(
connection: BuildServerConnection,
targets: Seq[b.BuildTargetIdentifier],
timeout: Timeout,
timeout: Option[Timeout],
): CancelableFuture[b.CompileResult] = {
val params = new b.CompileParams(targets.asJava)
targets.foreach(target => isCompiling(target) = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ object MetalsEnrichments
ec: ExecutionContext
): Future[A] = withTimeout(FiniteDuration(length, unit))

def withTimeout(duration: FiniteDuration)(implicit ec: ExecutionContext): Future[A] = {
def withTimeout(
duration: FiniteDuration
)(implicit ec: ExecutionContext): Future[A] = {
Future(Await.result(future, duration))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import java.util.concurrent.TimeoutException
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration
import scala.util.Failure
import scala.util.Success
import scala.util.Try
Expand All @@ -22,15 +21,14 @@ import scala.meta.internal.metals.MutableCancelable
import org.eclipse.lsp4j.services.LanguageClient

class RequestRegistry(
defaultMinTimeout: FiniteDuration,
initialCancellables: List[Cancelable],
languageClient: LanguageClient,
requestTimeOutNotification: Option[DismissedNotifications#Notification] =
None,
)(implicit
ex: ExecutionContext
) {
private val timeouts: Timeouts = new Timeouts(defaultMinTimeout)
private val timeouts: Timeouts = new Timeouts()
private val ongoingRequests =
new MutableCancelable().addAll(initialCancellables)

Expand All @@ -54,13 +52,14 @@ class RequestRegistry(

def register[T](
action: () => CompletableFuture[T],
timeout: Timeout,
timeout: Option[Timeout],
): CancelableFuture[T] = {
val CancelableFuture(result, cancelable) =
timeouts.getNameAndTimeout(timeout) match {
case Some((actionName, timeoutValue))
timeout match {
case Some(timeout)
if !requestTimeOutNotification.exists(_.isDismissed) =>
FutureWithTimeout(timeoutValue, onTimeout(actionName)(_))(action)
val timeoutValue = timeouts.getTimeout(timeout)
FutureWithTimeout(timeoutValue, onTimeout(timeout.name)(_))(action)
.transform {
case Success((res, time)) =>
timeouts.measured(timeout, time)
Expand Down Expand Up @@ -92,7 +91,6 @@ class RequestRegistry(
ongoingRequests.cancel()
}

def getTimeout(timeout: Timeout): Option[Duration] =
timeouts.getTimeout(timeout)
def getTimeout(timeout: Timeout): Duration = timeouts.getTimeout(timeout)

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,34 @@ import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.duration.FiniteDuration

sealed trait Timeout
case class Timeout(id: String, name: String, minTimeout: FiniteDuration)
object Timeout {
case object NoTimeout extends Timeout
case class DefaultFlexTimeout(id: String) extends Timeout
case class FlexTimeout(id: String, minTimeout: FiniteDuration) extends Timeout
def apply(name: String, minTimeout: FiniteDuration): Timeout =
Timeout(name, name, minTimeout)
def default(name: String, minTimeout: FiniteDuration): Timeout =
Timeout("default", name, minTimeout)
}

class Timeouts(defaultMinTimeout: FiniteDuration) {
private val defaultFlexTimeout: AtomicReference[Option[AvgTime]] =
new AtomicReference(None)
class Timeouts() {
private val timeouts: AtomicReference[Map[String, AvgTime]] =
new AtomicReference(Map())

def measured(timeout: Timeout, time: FiniteDuration): Any = {
val addToOption: Option[AvgTime] => Option[AvgTime] = {
case Some(avgTime) => Some(avgTime.add(time))
case None => Some(AvgTime.of(time))
}
timeout match {
case Timeout.DefaultFlexTimeout(_) =>
defaultFlexTimeout.getAndUpdate(addToOption(_))
case Timeout.FlexTimeout(id, _) =>
timeouts.getAndUpdate(_.updatedWith(id)(addToOption))
case _ =>
}
timeouts.getAndUpdate(_.updatedWith(timeout.id)(addToOption))
}

def getNameAndTimeout(timeout: Timeout): Option[(String, FiniteDuration)] = {
timeout match {
case Timeout.DefaultFlexTimeout(id) =>
Some(
defaultFlexTimeout.get
.map(_.avgWithMin(defaultMinTimeout))
.getOrElse(defaultMinTimeout)
).map((id, _))
case Timeout.FlexTimeout(id, minTimeout) =>
Some(
timeouts.get
.get(id)
.map(_.avgWithMin(minTimeout))
.getOrElse(minTimeout)
).map((id, _))
case Timeout.NoTimeout => None
}
def getTimeout(timeout: Timeout): FiniteDuration = {
val Timeout(id, _, minTimeout) = timeout
timeouts.get
.get(id)
.map(_.avgWithMin(minTimeout))
.getOrElse(minTimeout)
}

def getTimeout(timeout: Timeout): Option[FiniteDuration] =
getNameAndTimeout(timeout).map(_._2)
}

case class AvgTime(samples: Int, totalTime: Long) {
Expand Down
61 changes: 41 additions & 20 deletions tests/unit/src/test/scala/tests/RequestRegistrySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ class RequestRegistrySuite extends FunSuite {
implicit val ex: ExecutionContextExecutorService =
ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())
val duration: FiniteDuration = FiniteDuration(1, TimeUnit.SECONDS)
def createRegistry() =
new RequestRegistry(duration, List(), CancellingLanguageClient)
val defaultTimeout: Timeout.DefaultFlexTimeout =
Timeout.DefaultFlexTimeout("request")
def createRegistry(
onTimeout: () => MessageActionItem = () => Messages.RequestTimeout.cancel
) =
new RequestRegistry(List(), new RequestRegistrySuite.Client(onTimeout))
val defaultTimeout: Some[Timeout] = Some(Timeout.default("request", duration))

test("avg") {
val requestRegistry = createRegistry()
val doneTimeout = Timeout.FlexTimeout("done", duration)
val doneTimeout = Some(Timeout("done", duration))
for {
_ <- requestRegistry
.register(
Expand All @@ -57,14 +58,13 @@ class RequestRegistrySuite extends FunSuite {
_ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future
_ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future
_ = assert(
requestRegistry.getTimeout(defaultTimeout).get > duration
requestRegistry.getTimeout(defaultTimeout.get) > duration
)
_ = assert(
requestRegistry
.getTimeout(defaultTimeout)
.get < duration * 2
.getTimeout(defaultTimeout.get) < duration * 2
)
_ = assert(requestRegistry.getTimeout(doneTimeout).get == duration)
_ = assert(requestRegistry.getTimeout(doneTimeout.get) == duration)
} yield ()
}

Expand All @@ -78,7 +78,7 @@ class RequestRegistrySuite extends FunSuite {
.failed
_ = assert(err.isInstanceOf[TimeoutException])
_ = assertEquals(
requestRegistry.getTimeout(defaultTimeout).get,
requestRegistry.getTimeout(defaultTimeout.get),
duration * 3,
)
_ <- requestRegistry.register(ExampleFutures.fast, defaultTimeout).future
Expand All @@ -92,11 +92,11 @@ class RequestRegistrySuite extends FunSuite {
val requestRegistry = createRegistry()
val f1 = requestRegistry.register(
ExampleFutures.infinite(promise1),
Timeout.NoTimeout
timeout = None,
)
val f2 = requestRegistry.register(
ExampleFutures.infinite(promise2),
Timeout.NoTimeout
timeout = None,
)
requestRegistry.cancel()
for {
Expand All @@ -109,15 +109,36 @@ class RequestRegistrySuite extends FunSuite {
} yield ()
}

}
test("wait-on-timeout") {
val promise1 = Promise[Unit]()
var timeoutCount = 0
val requestRegistry = createRegistry(() => {
timeoutCount += 1
if (timeoutCount <= 1) Messages.RequestTimeout.waitAction
else Messages.RequestTimeout.cancel
})

for {
err <- requestRegistry
.register(
ExampleFutures.infinite(promise1),
defaultTimeout,
)
.future
.failed
_ = assert(err.isInstanceOf[TimeoutException])
_ = assert(timeoutCount == 2)
_ <- promise1.future
} yield ()
}

object RequestType extends Enumeration {
val Type1, Type2 = Value
}

object CancellingLanguageClient extends NoopLanguageClient {
override def showMessageRequest(
requestParams: ShowMessageRequestParams
): CompletableFuture[MessageActionItem] =
Future.successful(Messages.RequestTimeout.cancel).asJava
object RequestRegistrySuite {
class Client(onTimeout: () => MessageActionItem) extends NoopLanguageClient {
override def showMessageRequest(
requestParams: ShowMessageRequestParams
): CompletableFuture[MessageActionItem] =
Future.successful(onTimeout()).asJava
}
}
22 changes: 11 additions & 11 deletions tests/unit/src/test/scala/tests/TimeoutSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,25 @@ class TimeoutSuite extends FunSuite {

test("timeouts") {
def min(int: Int) = Duration(int, TimeUnit.MINUTES)
val timeouts = new Timeouts(min(3))
val flexTimeout = Timeout.FlexTimeout("flex", min(6))
assertEquals(timeouts.getTimeout(Timeout.NoTimeout), None)
val defaultTime = Timeout.default("request", min(3))
val timeouts = new Timeouts()
val flexTimeout = Timeout("flex", min(6))
assertEquals(
timeouts.getTimeout(Timeout.DefaultFlexTimeout("request")),
Some(min(3)),
timeouts.getTimeout(defaultTime),
min(3),
)
assertEquals(timeouts.getTimeout(flexTimeout), Some(min(6)))
timeouts.measured(Timeout.DefaultFlexTimeout("request"), min(1))
timeouts.measured(Timeout.DefaultFlexTimeout("request"), min(3))
assertEquals(timeouts.getTimeout(flexTimeout), min(6))
timeouts.measured(defaultTime, min(1))
timeouts.measured(defaultTime, min(3))
timeouts.measured(flexTimeout, min(1))
timeouts.measured(flexTimeout, min(2))
// avg * 3 > min
assertEquals(
timeouts.getTimeout(Timeout.DefaultFlexTimeout("request")),
Some(min(6)),
timeouts.getTimeout(defaultTime),
min(6),
)
// avg * 3 < min
assertEquals(timeouts.getTimeout(flexTimeout), Some(min(6)))
assertEquals(timeouts.getTimeout(flexTimeout), min(6))
}
}

Expand Down

0 comments on commit 7c0251d

Please sign in to comment.