diff --git a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala index 5f42371bf38..306082a7573 100644 --- a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala +++ b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala @@ -19,6 +19,10 @@ import pekko.stream.{ Materializer, SystemMaterializer } import pekko.stream.impl.PhasedFusingActorMaterializer import pekko.stream.testkit.scaladsl +import java.time.Duration +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration + object StreamTestKit { /** @@ -29,7 +33,21 @@ object StreamTestKit { def assertAllStagesStopped(mat: Materializer): Unit = mat match { case impl: PhasedFusingActorMaterializer => - scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor) + scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor, None) + case _ => + } + + /** + * Assert that there are no stages running under a given materializer. + * Usually this assertion is run after a test-case to check that all of the + * stages have terminated successfully with an overridden duration that ignores + * `stream.testkit.all-stages-stopped-timeout`. + */ + def assertAllStagesStopped(mat: Materializer, overrideTimeout: Duration): Unit = + mat match { + case impl: PhasedFusingActorMaterializer => + scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor, + Some(FiniteDuration(overrideTimeout.toMillis, TimeUnit.MILLISECONDS))) case _ => } diff --git a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala index 85d5d5623b8..e5797f59634 100644 --- a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala +++ b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala @@ -35,12 +35,30 @@ object StreamTestKit { * This assertion is useful to check that all of the stages have * terminated successfully. */ + def assertAllStagesStopped[T](block: => T, overrideTimeout: FiniteDuration)(implicit materializer: Materializer): T = + materializer match { + case impl: PhasedFusingActorMaterializer => + stopAllChildren(impl.system, impl.supervisor) + val result = block + assertNoChildren(impl.system, impl.supervisor, Some(overrideTimeout)) + result + case _ => block + } + + /** + * Asserts that after the given code block is ran, no stages are left over + * that were created by the given materializer with an overridden duration + * that ignores `stream.testkit.all-stages-stopped-timeout`. + * + * This assertion is useful to check that all of the stages have + * terminated successfully. + */ def assertAllStagesStopped[T](block: => T)(implicit materializer: Materializer): T = materializer match { case impl: PhasedFusingActorMaterializer => stopAllChildren(impl.system, impl.supervisor) val result = block - assertNoChildren(impl.system, impl.supervisor) + assertNoChildren(impl.system, impl.supervisor, None) result case _ => block } @@ -53,10 +71,16 @@ object StreamTestKit { } /** INTERNAL API */ - @InternalApi private[testkit] def assertNoChildren(sys: ActorSystem, supervisor: ActorRef): Unit = { + @InternalApi private[testkit] def assertNoChildren(sys: ActorSystem, supervisor: ActorRef, + overrideTimeout: Option[FiniteDuration]): Unit = { val probe = TestProbe()(sys) - val c = sys.settings.config.getConfig("pekko.stream.testkit") - val timeout = c.getDuration("all-stages-stopped-timeout", MILLISECONDS).millis + val timeout = overrideTimeout match { + case Some(value) => value + case None => + val c = sys.settings.config.getConfig("pekko.stream.testkit") + c.getDuration("all-stages-stopped-timeout", MILLISECONDS).millis + } + probe.within(timeout) { try probe.awaitAssert { supervisor.tell(StreamSupervisor.GetChildren, probe.ref) diff --git a/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala index 827a3c4918c..501da0ef2af 100644 --- a/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala +++ b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala @@ -31,6 +31,7 @@ import org.scalatest.Failed import com.typesafe.config.{ Config, ConfigFactory } abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) { + private var _allStagesStoppedTimeout: Option[FiniteDuration] = None def this(config: Config) = this( ActorSystem( @@ -43,6 +44,31 @@ abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) { def this() = this(ActorSystem(TestKitUtils.testNameFromCallStack(classOf[StreamSpec], "".r), PekkoSpec.testConf)) + def this(config: Config, overrideAllStagesStoppedTimeout: FiniteDuration) = { + this(config) + _allStagesStoppedTimeout = Some(overrideAllStagesStoppedTimeout) + } + + def this(s: String, overrideAllStagesStoppedTimeout: FiniteDuration) = { + this(s) + _allStagesStoppedTimeout = Some(overrideAllStagesStoppedTimeout) + } + + def this(configMap: Map[String, _], overrideAllStagesStoppedTimeout: FiniteDuration) = { + this(configMap) + _allStagesStoppedTimeout = Some(overrideAllStagesStoppedTimeout) + } + + def this(overrideAllStagesStoppedTimeout: FiniteDuration) = { + this() + _allStagesStoppedTimeout = Some(overrideAllStagesStoppedTimeout) + } + + def this(_system: ActorSystem, overrideAllStagesStoppedTimeout: FiniteDuration) = { + this(_system) + _allStagesStoppedTimeout = Some(overrideAllStagesStoppedTimeout) + } + override def withFixture(test: NoArgTest) = { super.withFixture(test) match { case failed: Failed => @@ -73,7 +99,7 @@ abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) { case impl: PhasedFusingActorMaterializer => stopAllChildren(impl.system, impl.supervisor) val result = test.apply() - assertNoChildren(impl.system, impl.supervisor) + assertNoChildren(impl.system, impl.supervisor, _allStagesStoppedTimeout) result case _ => other }