Skip to content

Commit

Permalink
Merge pull request #55 from lucidsoftware/more-worker-updates
Browse files Browse the repository at this point in the history
Sends verbosity from the worker protocol to the worker, enables Java toolchain multiplex sandboxing, and fixes the launcher to work with multiplex sandboxing
  • Loading branch information
jjudd authored Sep 9, 2024
2 parents 73faf81 + ef5f55e commit f23c160
Show file tree
Hide file tree
Showing 20 changed files with 157 additions and 38 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ default_java_toolchain(
name = "repository_default_toolchain_21",
configuration = DEFAULT_TOOLCHAIN_CONFIGURATION,
java_runtime = "@rules_java//toolchains:remotejdk_21",
javac_supports_worker_multiplex_sandboxing = True,
source_version = "21",
target_version = "21",
)
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ protobuf_deps()
# rules_java
http_archive(
name = "rules_java",
sha256 = "80e61f508ff79a3fde4a549b8b1f6ec7f8a82c259e51240a4403e5be36f88142",
sha256 = "41131de4417de70b9597e6ebd515168ed0ba843a325dc54a81b92d7af9a7b3ea",
urls = [
"https://github.com/bazelbuild/rules_java/releases/download/7.6.4/rules_java-7.6.4.tar.gz",
"https://github.com/bazelbuild/rules_java/releases/download/7.9.0/rules_java-7.9.0.tar.gz",
],
)

Expand Down
76 changes: 52 additions & 24 deletions rules/common/private/utils.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@bazel_skylib//lib:dicts.bzl", "dicts")
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_skylib//lib:shell.bzl", "shell")

#
# Helper utilities
Expand Down Expand Up @@ -31,6 +32,19 @@ def _strip_margin_line(line, delim):

_SINGLE_JAR_MNEMONIC = "SingleJar"

def _format_jacoco_metadata_file(runfiles_enabled, workspace_prefix, metadata_file):
if runfiles_enabled:
return "export JACOCO_METADATA_JAR=\"$JAVA_RUNFILES/{}/{}\"".format(workspace_prefix, metadata_file.short_path)

return "export JACOCO_METADATA_JAR=$(rlocation " + paths.normalize(workspace_prefix + metadata_file.short_path) + ")"

# This is from the Starlark Java builtins in Bazel
def _format_classpath_entry(runfiles_enabled, workspace_prefix, file):
if runfiles_enabled:
return "${RUNPATH}" + file.short_path

return "$(rlocation " + paths.normalize(workspace_prefix + file.short_path) + ")"

def write_launcher(
ctx,
prefix,
Expand All @@ -40,40 +54,48 @@ def write_launcher(
jvm_flags,
extra = "",
jacoco_classpath = None):
"""Macro that writes out a launcher script shell script.
"""Macro that writes out a launcher script shell script. Some of this is from Bazel's Starlark Java builtins.
Args:
runtime_classpath: File containing the classpath required to launch this java target.
main_class: the main class to launch.
jvm_flags: The flags that should be passed to the jvm.
args: Args that should be passed to the Binary.
"""
workspace_name = ctx.workspace_name
workspace_prefix = workspace_name + ("/" if workspace_name else "")

classpath_args = ctx.actions.args()
classpath_args.add_joined(runtime_classpath, format_each = "${RUNPATH}%s", join_with = ":", map_each = _short_path)
classpath_args.set_param_file_format("multiline")
classpath_file = ctx.actions.declare_file("{}classpath.params".format(prefix))
ctx.actions.write(classpath_file, classpath_args)

classpath = "\"$(eval echo \"$(cat ${{RUNPATH}}{})\")\"".format(classpath_file.short_path)

jvm_flags = " ".join(jvm_flags)
template = ctx.file._java_stub_template
# TODO: can we get this info?
# runfiles_enabled = ctx.configuration.runfiles_enabled()
runfiles_enabled = False

java_executable = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].java_executable_runfiles_path
java_path = str(java_executable)
if paths.is_absolute(java_path):
javabin = java_path
java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo]
java_executable = java_runtime_info.java_executable_runfiles_path
if not paths.is_absolute(java_executable):
java_executable = workspace_name + "/" + java_executable
java_executable = paths.normalize(java_executable)

if runfiles_enabled:
prefix = "" if paths.is_absolute(java_executable) else "${JAVA_RUNFILES}/"
javabin = "JAVABIN=${JAVABIN:-" + prefix + java_executable + "}"
else:
javabin = "$JAVA_RUNFILES/{}/{}".format(ctx.workspace_name, java_executable)
javabin = "JAVABIN=${JAVABIN:-$(rlocation " + java_executable + ")}"

template_dict = ctx.actions.template_dict()
template_dict.add_joined(
"%classpath%",
runtime_classpath,
map_each = lambda file: _format_classpath_entry(runfiles_enabled, workspace_prefix, file),
join_with = ctx.configuration.host_path_separator,
format_joined = "\"%s\"",
allow_closure = True,
)

base_substitutions = {
"%classpath%": classpath,
"%javabin%": "JAVABIN=\"{}\"\n{}".format(javabin, extra),
"%jvm_flags%": jvm_flags,
"%needs_runfiles%": "1" if runfiles_enabled else "",
"%runfiles_manifest_only%": "1" if runfiles_enabled else "",
"%workspace_prefix%": ctx.workspace_name + "/",
"%workspace_prefix%": workspace_prefix,
"%javabin%": "{}\n{}".format(javabin, extra),
"%needs_runfiles%": "0" if paths.is_absolute(java_runtime_info.java_executable_exec_path) else "1",
"%jvm_flags%": " ".join(jvm_flags),
"%test_runtime_classpath_file%": "",
}

Expand All @@ -86,9 +108,14 @@ def write_launcher(
for jar in jacoco_classpath
]))
more_outputs = [metadata_file]

template_dict.add(
"%set_jacoco_metadata%",
_format_jacoco_metadata_file(runfiles_enabled, workspace_prefix, metadata_file),
)

more_substitutions = {
"%java_start_class%": "com.google.testing.coverage.JacocoCoverageRunner",
"%set_jacoco_metadata%": "export JACOCO_METADATA_JAR=\"$JAVA_RUNFILES/{}/{}\"".format(ctx.workspace_name, metadata_file.short_path),
"%set_jacoco_main_class%": """export JACOCO_MAIN_CLASS={}""".format(main_class),
"%set_jacoco_java_runfiles_root%": """export JACOCO_JAVA_RUNFILES_ROOT=$JAVA_RUNFILES/{}/""".format(ctx.workspace_name),
"%set_java_coverage_new_implementation%": """export JAVA_COVERAGE_NEW_IMPLEMENTATION=YES""",
Expand All @@ -104,13 +131,14 @@ def write_launcher(
}

ctx.actions.expand_template(
template = template,
template = ctx.file._java_stub_template,
output = output,
substitutions = dicts.add(base_substitutions, more_substitutions),
computed_substitutions = template_dict,
is_executable = True,
)

return more_outputs + [classpath_file]
return more_outputs

def safe_name(value):
return "".join([value[i] if value[i].isalnum() or value[i] == "." else "_" for i in range(len(value))])
Expand Down
2 changes: 1 addition & 1 deletion rules/scala_proto/private/ScalaProtoWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ object ScalaProtoWorker extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

protected def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
protected def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = ScalaProtoRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
2 changes: 1 addition & 1 deletion rules/scalafmt/scalafmt/ScalafmtRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object ScalafmtRunner extends WorkerMain[Unit] {

protected[this] def init(args: Option[Array[String]]): Unit = {}

protected[this] def work(worker: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
protected[this] def work(worker: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = ScalafmtRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream

protected[this] def init(args: Option[Array[String]]): S

protected[this] def work(ctx: S, args: Array[String], out: PrintStream, workDir: Path): Unit
protected[this] def work(ctx: S, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit

protected[this] var isWorker = false

Expand Down Expand Up @@ -101,6 +101,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
} else {
val args = request.getArgumentsList.toArray(Array.empty[String])
val sandboxDir = Path.of(request.getSandboxDir())
val verbosity = request.getVerbosity()
System.err.println(s"WorkRequest $requestId received with args: ${request.getArgumentsList}")

// We go through this hullabaloo with output streams being defined out here, so we can
Expand All @@ -114,7 +115,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
outStream = new ByteArrayOutputStream
out = new PrintStream(outStream)
try {
work(ctx, args, out, sandboxDir)
work(ctx, args, out, sandboxDir, verbosity)
0
} catch {
case e @ AnnexWorkerError(code, _, _) =>
Expand Down Expand Up @@ -210,6 +211,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
args.toArray,
out,
workDir = Path.of(""),
verbosity = 0,
)
} catch {
// This error means the work function encountered an error that we want to not be caught
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ import java.nio.file.Path

object BloopRunner extends WorkerMain[Unit] {
override def init(args: Option[Array[String]]): Unit = ()
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = Bloop
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = Bloop
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ object DepsRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = DepsRunnerRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ object JacocoInstrumenter extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = JacocoRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))

val jacoco = new Instrumenter(new OfflineInstrumentationAccessGenerator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ object ZincRunner extends WorkerMain[ZincRunnerWorkerConfig] {
args: Array[String],
out: PrintStream,
workDir: Path,
verbosity: Int,
): Unit = {
val workRequest = CommonArguments(ArgsUtil.parseArgsOrFailSafe(args, parser, out), workDir)
InterruptUtil.throwIfInterrupted()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object DocRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = DocRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ default_java_toolchain(
name = "repository_default_toolchain_21",
configuration = DEFAULT_TOOLCHAIN_CONFIGURATION,
java_runtime = "@rules_java//toolchains:remotejdk_21",
javac_supports_worker_multiplex_sandboxing = True,
source_version = "21",
target_version = "21",
)
2 changes: 1 addition & 1 deletion tests/cancellation/RunnerForCancelSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerForCancelSpec(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
var interrupted = false
var i = 0

Expand Down
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsError(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new Error()
}
}
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsException.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsException(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new Exception()
}
}
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsFatalError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsFatalError(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new OutOfMemoryError()
}
}
20 changes: 20 additions & 0 deletions tests/worker-verbosity/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@rules_scala_annex//rules:scala.bzl", "scala_binary")
load("verbosity_spec_worker_run.bzl", "verbosity_spec_worker_run")

scala_binary(
name = "verbosity-spec-worker",
srcs = [
"RunnerThatPrintsVerbosity.scala",
],
scala = "//scala:2_13",
tags = ["manual"],
deps = [
"@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sandbox",
"@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/worker",
],
)

verbosity_spec_worker_run(
name = "verbosity-spec-target",
verbosity_spec_worker = ":verbosity-spec-worker",
)
15 changes: 15 additions & 0 deletions tests/worker-verbosity/RunnerThatPrintsVerbosity.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package anx.cancellation

import higherkindness.rules_scala.common.worker.WorkerMain
import higherkindness.rules_scala.common.sandbox.SandboxUtil

import java.io.{InputStream, PrintStream}
import java.nio.file.{Files, Path, Paths}

object RunnerThatPrintsVerbosity extends WorkerMain[Unit] {
override def init(args: Option[Array[String]]): Unit = ()
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
out.println(s"Verbosity: ${verbosity}")
Files.createFile(SandboxUtil.getSandboxPath(workDir, Paths.get(args(0))))
}
}
12 changes: 12 additions & 0 deletions tests/worker-verbosity/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash -e
. "$(dirname "$0")"/../common.sh

# We use modify_execution_info, nouse_action_cache, and bazel shutdown here
# in order to prevent the disk cache, skyframe cache, and persistent action cache
# from being used for the verbosity spec worker actions and thus getting the
# verbosity we want getting printed. The alternative is to bazel clean, which
# takes much longer.
bazel shutdown
bazel build --modify_execution_info="VerbositySpecWorkerRun=+no-cache" --nouse_action_cache :verbosity-spec-target |& grep -q "Verbosity: 0"
bazel shutdown
bazel build --modify_execution_info="VerbositySpecWorkerRun=+no-cache" --nouse_action_cache --worker_verbose :verbosity-spec-target |& grep -q "Verbosity: 10"
39 changes: 39 additions & 0 deletions tests/worker-verbosity/verbosity_spec_worker_run.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
def _impl(ctx):
foo_file = ctx.actions.declare_file("foo.txt")
outputs = [foo_file]

args = ctx.actions.args()
args.add(foo_file)
args.set_param_file_format("multiline")
args.use_param_file("@%s", use_always = True)

ctx.actions.run(
outputs = outputs,
arguments = [args],
mnemonic = "VerbositySpecWorkerRun",
execution_requirements = {
"supports-multiplex-workers": "1",
"supports-workers": "1",
"supports-multiplex-sandboxing": "1",
"supports-worker-cancellation": "1",
},
progress_message = "Running verbosity spec worker %{label}",
executable = ctx.executable.verbosity_spec_worker,
)

return [
DefaultInfo(files = depset(outputs)),
]

verbosity_spec_worker_run = rule(
implementation = _impl,
doc = "Runs a worker that prints the verbosity level it received from the work request",
attrs = {
"verbosity_spec_worker": attr.label(
executable = True,
cfg = "host",
allow_files = True,
default = Label(":verbosity-spec-worker"),
),
},
)

0 comments on commit f23c160

Please sign in to comment.