Skip to content

Commit

Permalink
feat: ability to set number of parallel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timothyklim committed Mar 29, 2024
1 parent db790f3 commit ea1040c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 14 deletions.
1 change: 1 addition & 0 deletions rules/private/phases/phase_test_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def phase_test_launcher(ctx, g):

args = ctx.actions.args()
args.add("--parallel", ctx.attr.parallel)
args.add(ctx.attr.parallel_n, format = "--parallel-n=%s")
args.add("--apis", g.compile.zinc_info.apis.short_path)
args.add_all(ctx.attr.frameworks, format_each = "--framework=%s")
if ctx.attr.isolation == "classloader":
Expand Down
1 change: 1 addition & 0 deletions rules/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def make_scala_test(*extras):
),
"runner": attr.label(default = "@rules_scala3//scala/workers/zinc/test"),
"parallel": attr.bool(default = True),
"parallel_n": attr.int(default = 8),
"subprocess_runner": attr.label(default = "@rules_scala3//scala/common/sbt-testing:subprocess"),
},
_extras_attributes(extras),
Expand Down
8 changes: 4 additions & 4 deletions scala/common/sbt-testing/Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import scala.util.control.NonFatal

final case class TestDefinition(name: String, fingerprint: Fingerprint)

final class TestFrameworkLoader(loader: ClassLoader, logger: Logger):
final class TestFrameworkLoader(loader: ClassLoader):
def load(className: String) =
val framework =
try Some(Class.forName(className, true, loader).getDeclaredConstructor().newInstance())
Expand Down Expand Up @@ -37,23 +37,23 @@ object TestHelper:
)

final class TestReporter(logger: Logger):
def post(failures: Traversable[String]) =
def post(failures: Iterable[String]) =
if failures.nonEmpty then
logger.error(s"${failures.size} ${if failures.size == 1 then "failure" else "failures"}:")
failures.toSeq.sorted.foreach(name => logger.error(s" $name"))
logger.error("")

def postTask() = logger.info("")

def pre(framework: Framework, tasks: Traversable[Task]) =
def pre(framework: Framework, tasks: Iterable[Task]) =
logger.info(s"${framework.getClass.getName}: ${tasks.size} tests")
logger.info("")

def preTask(task: Task) = logger.info(task.taskDef.fullyQualifiedName)

final class TestTaskExecutor(logger: Logger):
def execute(task: Task, failures: mutable.Set[String]): mutable.ListBuffer[Event] =
var events = mutable.ListBuffer[Event]()
val events = mutable.ListBuffer[Event]()
val pending = mutable.Set.empty[String]

def execute(task: Task): Unit =
Expand Down
17 changes: 10 additions & 7 deletions scala/workers/zinc/test/TestFrameworkRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import common.sbt_testing.*

final case class FinishedTask(name: String, events: collection.Seq[Event], failures: collection.Set[String])

final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, parallel: Boolean, logger: Logger) extends TestFrameworkRunner:
final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, parallel: Boolean, parallelN: Int, logger: Logger) extends TestFrameworkRunner:
def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]) =
ClassLoaders.withContextClassLoader(classLoader) {
TestHelper.withRunner(framework, scopeAndTestName, classLoader, arguments) { runner =>
Expand All @@ -24,12 +24,12 @@ final class BasicTestRunner(framework: Framework, classLoader: ClassLoader, para
reporter.pre(framework, tasks)
given taskExecutor: TestTaskExecutor = TestTaskExecutor(logger)

val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel)
val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel, parallelN = parallelN)
TestFrameworkRunner.report(tasksAndEvents, failures)
}
}

final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => ClassLoader, parallel: Boolean, logger: Logger)
final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => ClassLoader, parallel: Boolean, parallelN: Int, logger: Logger)
extends TestFrameworkRunner:
def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]) =
given reporter: TestReporter = TestReporter(logger)
Expand All @@ -48,11 +48,11 @@ final class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: ()

for test <- tests do
val classLoader = classLoaderProvider()
val isolatedFramework = TestFrameworkLoader(classLoader, logger).load(framework.getClass.getName).get
val isolatedFramework = TestFrameworkLoader(classLoader).load(framework.getClass.getName).get
TestHelper.withRunner(isolatedFramework, scopeAndTestName, classLoader, arguments) { runner =>
ClassLoaders.withContextClassLoader(classLoader) {
val tasks = runner.tasks(Array(TestHelper.taskDef(test, scopeAndTestName)))
val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel)
val (tasksAndEvents, failures) = TestFrameworkRunner.run(tasks, parallel = parallel, parallelN = parallelN)
totalTasksAndEvents ++= tasksAndEvents
totalFailures ++= failures
}
Expand Down Expand Up @@ -109,12 +109,15 @@ sealed trait TestFrameworkRunner:
def execute(tests: Seq[TestDefinition], scopeAndTestName: String, arguments: Seq[String]): Boolean

object TestFrameworkRunner:
def run(tasks: collection.Seq[Task], parallel: Boolean)(
def run(tasks: collection.Seq[Task], parallel: Boolean, parallelN: Int)(
using taskExecutor: TestTaskExecutor,
reporter: TestReporter
): (mutable.ListBuffer[(String, collection.Seq[Event])], collection.Set[String]) =
val finishedTasks =
if parallel then Await.result(Future.sequence(tasks.map(t => Future(runTask(t)))), Duration.Inf)
if parallel then
val fut = Future.traverse(tasks.grouped(parallelN)): xs =>
Future.sequence(xs.map(t => Future(runTask(t))))
Await.result(fut, Duration.Inf).flatten
else tasks.map(runTask(_))

val failures = mutable.Set.empty[String]
Expand Down
8 changes: 5 additions & 3 deletions scala/workers/zinc/test/TestRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ object TestRunnerArguments:

final case class TestWorkArguments(
parallel: Boolean = false,
parallelN: Int = 1,
apis: Path = Paths.get("."),
subprocessExec: Path = Paths.get("."),
isolation: Isolation = Isolation.None,
Expand All @@ -71,6 +72,7 @@ object TestWorkArguments:

private val parser = OParser.sequence(
opt[Boolean]("parallel").optional().action((v, c) => c.copy(parallel = v)),
opt[Int]("parallel-n").optional().action((v, c) => c.copy(parallelN = Math.max(1, v))),
opt[File]("apis").required().action((f, c) => c.copy(apis = f.toPath)).text("APIs file"),
opt[File]("subprocess_exec").optional().action((f, c) => c.copy(subprocessExec = f.toPath)).text("Executable for SubprocessTestRunner"),
opt[Isolation]("isolation").optional().action((iso, c) => c.copy(isolation = iso)).text("Test isolation"),
Expand Down Expand Up @@ -133,7 +135,7 @@ object TestRunner:
ProtobufReaders(ReadMapper.getEmptyMapper, Schema.Version.V1_1).fromApis(shouldStoreApis = true)(raw)
catch case NonFatal(e) => throw Exception(s"Failed to load APIs from $apisFile", e)

val loader = TestFrameworkLoader(classLoader, logger)
val loader = TestFrameworkLoader(classLoader)
val frameworks = workArgs.frameworks.flatMap(loader.load)

val testClass = sys.env
Expand Down Expand Up @@ -163,11 +165,11 @@ object TestRunner:
case Isolation.ClassLoader =>
val urls = classpath.filterNot(sharedClasspath.toSet).map(_.toUri.toURL).toArray
def classLoaderProvider() = URLClassLoader(urls, sharedClassLoader)
ClassLoaderTestRunner(framework, classLoaderProvider, parallel = workArgs.parallel, logger)
ClassLoaderTestRunner(framework, classLoaderProvider, parallel = workArgs.parallel, parallelN = workArgs.parallelN, logger)
case Isolation.Process =>
val executable = runPath.resolve(workArgs.subprocessExec)
ProcessTestRunner(framework, classpath, ProcessCommand(executable.toString, runArgs.subprocessArg), logger)
case Isolation.None => BasicTestRunner(framework, classLoader, parallel = workArgs.parallel, logger)
case Isolation.None => BasicTestRunner(framework, classLoader, parallel = workArgs.parallel, parallelN = workArgs.parallelN, logger)
try runner.execute(filteredTests, testScopeAndName.getOrElse(""), runArgs.frameworkArgs)
catch
case e: Throwable =>
Expand Down

0 comments on commit ea1040c

Please sign in to comment.