diff --git a/rules/private/phases/phase_test_launcher.bzl b/rules/private/phases/phase_test_launcher.bzl index 5f2c943d..80c06b07 100644 --- a/rules/private/phases/phase_test_launcher.bzl +++ b/rules/private/phases/phase_test_launcher.bzl @@ -36,6 +36,7 @@ def phase_test_launcher(ctx, g): args = ctx.actions.args() args.add("--parallel", ctx.attr.parallel) + args.add(ctx.attr.parallel_n, format = "--parallel-n=%s") args.add("--apis", g.compile.zinc_info.apis.short_path) args.add_all(ctx.attr.frameworks, format_each = "--framework=%s") if ctx.attr.isolation == "classloader": diff --git a/rules/scala.bzl b/rules/scala.bzl index c8daf6bb..e09c91f6 100644 --- a/rules/scala.bzl +++ b/rules/scala.bzl @@ -340,6 +340,7 @@ def make_scala_test(*extras): ), "runner": attr.label(default = "@rules_scala3//scala/workers/zinc/test"), "parallel": attr.bool(default = True), + "parallel_n": attr.int(default = 8), "subprocess_runner": attr.label(default = "@rules_scala3//scala/common/sbt-testing:subprocess"), }, _extras_attributes(extras), diff --git a/scala/common/sbt-testing/Test.scala b/scala/common/sbt-testing/Test.scala index 631a43d0..310e2472 100644 --- a/scala/common/sbt-testing/Test.scala +++ b/scala/common/sbt-testing/Test.scala @@ -7,7 +7,7 @@ import scala.util.control.NonFatal final case class TestDefinition(name: String, fingerprint: Fingerprint) -final class TestFrameworkLoader(loader: ClassLoader, logger: Logger): +final class TestFrameworkLoader(loader: ClassLoader): def load(className: String) = val framework = try Some(Class.forName(className, true, loader).getDeclaredConstructor().newInstance()) @@ -37,7 +37,7 @@ object TestHelper: ) final class TestReporter(logger: Logger): - def post(failures: Traversable[String]) = + def post(failures: Iterable[String]) = if failures.nonEmpty then logger.error(s"${failures.size} ${if failures.size == 1 then "failure" else "failures"}:") failures.toSeq.sorted.foreach(name => logger.error(s" $name")) @@ -45,7 +45,7 @@ final class TestReporter(logger: Logger): def postTask() = logger.info("") - def pre(framework: Framework, tasks: Traversable[Task]) = + def pre(framework: Framework, tasks: Iterable[Task]) = logger.info(s"${framework.getClass.getName}: ${tasks.size} tests") logger.info("") @@ -53,7 +53,7 @@ final class TestReporter(logger: Logger): final class TestTaskExecutor(logger: Logger): def execute(task: Task, failures: mutable.Set[String]): mutable.ListBuffer[Event] = - var events = mutable.ListBuffer[Event]() + val events = mutable.ListBuffer[Event]() val pending = mutable.Set.empty[String] def execute(task: Task): Unit = diff --git a/scala/workers/zinc/test/TestFrameworkRunner.scala b/scala/workers/zinc/test/TestFrameworkRunner.scala index a55de90b..10e4e9d4 100644 --- a/scala/workers/zinc/test/TestFrameworkRunner.scala +++ b/scala/workers/zinc/test/TestFrameworkRunner.scala @@ -15,7 +15,7 @@ import common.sbt_testing.* final case class FinishedTask(name: String, events: collection.Seq[Event], failures: collection.Set[String]) -final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, parallel: Boolean, logger: Logger) extends TestFrameworkRunner: +final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, parallel: Boolean, parallelN: Int, logger: Logger) extends TestFrameworkRunner: def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]) = ClassLoaders.withContextClassLoader(classLoader) { TestHelper.withRunner(framework, scopeAndTestName, classLoader, arguments) { runner => @@ -24,12 +24,12 @@ final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, para reporter.pre(framework, tasks) given taskExecutor: TestTaskExecutor = TestTaskExecutor(logger) - val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel) + val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel, parallelN = parallelN) TestFrameworkRunner.report(tasksAndEvents, failures) } } -final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => ClassLoader, parallel: Boolean, logger: Logger) +final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => ClassLoader, parallel: Boolean, parallelN: Int, logger: Logger) extends TestFrameworkRunner: def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]) = given reporter: TestReporter = TestReporter(logger) @@ -48,11 +48,11 @@ final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () for test <- tests do val classLoader = classLoaderProvider() - val isolatedFramework = TestFrameworkLoader(classLoader, logger).load(framework.getClass.getName).get + val isolatedFramework = TestFrameworkLoader(classLoader).load(framework.getClass.getName).get TestHelper.withRunner(isolatedFramework, scopeAndTestName, classLoader, arguments) { runner => ClassLoaders.withContextClassLoader(classLoader) { val tasks = runner.tasks(Array(TestHelper.taskDef(test, scopeAndTestName))) - val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel) + val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel, parallelN = parallelN) totalTasksAndEvents ++= tasksAndEvents totalFailures ++= failures } @@ -109,12 +109,15 @@ sealed trait TestFrameworkRunner: def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]): Boolean object TestFrameworkRunner: - def run(tasks: collection.Seq[Task], parallel: Boolean)( + def run(tasks: collection.Seq[Task], parallel: Boolean, parallelN: Int)( using taskExecutor: TestTaskExecutor, reporter: TestReporter ): (mutable.ListBuffer[(String, collection.Seq[Event])], collection.Set[String]) = val finishedTasks = - if parallel then Await.result(Future.sequence(tasks.map(t => Future(runTask(t)))), Duration.Inf) + if parallel then + val fut = Future.traverse(tasks.grouped(parallelN)): xs => + Future.sequence(xs.map(t => Future(runTask(t)))) + Await.result(fut, Duration.Inf).flatten else tasks.map(runTask(_)) val failures = mutable.Set.empty[String] diff --git a/scala/workers/zinc/test/TestRunner.scala b/scala/workers/zinc/test/TestRunner.scala index 5b260aec..b56ed8b2 100644 --- a/scala/workers/zinc/test/TestRunner.scala +++ b/scala/workers/zinc/test/TestRunner.scala @@ -58,6 +58,7 @@ object TestRunnerArguments: final case class TestWorkArguments( parallel: Boolean = false, + parallelN: Int = 1, apis: Path = Paths.get("."), subprocessExec: Path = Paths.get("."), isolation: Isolation = Isolation.None, @@ -71,6 +72,7 @@ object TestWorkArguments: private val parser = OParser.sequence( opt[Boolean]("parallel").optional().action((v, c) => c.copy(parallel = v)), + opt[Int]("parallel-n").optional().action((v, c) => c.copy(parallelN = Math.max(1, v))), opt[File]("apis").required().action((f, c) => c.copy(apis = f.toPath)).text("APIs file"), opt[File]("subprocess_exec").optional().action((f, c) => c.copy(subprocessExec = f.toPath)).text("Executable for SubprocessTestRunner"), opt[Isolation]("isolation").optional().action((iso, c) => c.copy(isolation = iso)).text("Test isolation"), @@ -133,7 +135,7 @@ object TestRunner: ProtobufReaders(ReadMapper.getEmptyMapper, Schema.Version.V1_1).fromApis(shouldStoreApis = true)(raw) catch case NonFatal(e) => throw Exception(s"Failed to load APIs from $apisFile", e) - val loader = TestFrameworkLoader(classLoader, logger) + val loader = TestFrameworkLoader(classLoader) val frameworks = workArgs.frameworks.flatMap(loader.load) val testClass = sys.env @@ -163,11 +165,11 @@ object TestRunner: case Isolation.ClassLoader => val urls = classpath.filterNot(sharedClasspath.toSet).map(_.toUri.toURL).toArray def classLoaderProvider() = URLClassLoader(urls, sharedClassLoader) - ClassLoaderTestRunner(framework, classLoaderProvider, parallel = workArgs.parallel, logger) + ClassLoaderTestRunner(framework, classLoaderProvider, parallel = workArgs.parallel, parallelN = workArgs.parallelN, logger) case Isolation.Process => val executable = runPath.resolve(workArgs.subprocessExec) ProcessTestRunner(framework, classpath, ProcessCommand(executable.toString, runArgs.subprocessArg), logger) - case Isolation.None => BasicTestRunner(framework, classLoader, parallel = workArgs.parallel, logger) + case Isolation.None => BasicTestRunner(framework, classLoader, parallel = workArgs.parallel, parallelN = workArgs.parallelN, logger) try runner.execute(filteredTests, testScopeAndName.getOrElse(""), runArgs.frameworkArgs) catch case e: Throwable =>