Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sends verbosity from the worker protocol to the worker and enable Java toolchain multiplex sandboxing #55

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ default_java_toolchain(
name = "repository_default_toolchain_21",
configuration = DEFAULT_TOOLCHAIN_CONFIGURATION,
java_runtime = "@rules_java//toolchains:remotejdk_21",
javac_supports_worker_multiplex_sandboxing = True,
source_version = "21",
target_version = "21",
)
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ protobuf_deps()
# rules_java
http_archive(
name = "rules_java",
sha256 = "80e61f508ff79a3fde4a549b8b1f6ec7f8a82c259e51240a4403e5be36f88142",
sha256 = "41131de4417de70b9597e6ebd515168ed0ba843a325dc54a81b92d7af9a7b3ea",
urls = [
"https://github.com/bazelbuild/rules_java/releases/download/7.6.4/rules_java-7.6.4.tar.gz",
"https://github.com/bazelbuild/rules_java/releases/download/7.9.0/rules_java-7.9.0.tar.gz",
],
)

Expand Down
76 changes: 52 additions & 24 deletions rules/common/private/utils.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@bazel_skylib//lib:dicts.bzl", "dicts")
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_skylib//lib:shell.bzl", "shell")

#
# Helper utilities
Expand Down Expand Up @@ -31,6 +32,19 @@ def _strip_margin_line(line, delim):

_SINGLE_JAR_MNEMONIC = "SingleJar"

def _format_jacoco_metadata_file(runfiles_enabled, workspace_prefix, metadata_file):
if runfiles_enabled:
return "export JACOCO_METADATA_JAR=\"$JAVA_RUNFILES/{}/{}\"".format(workspace_prefix, metadata_file.short_path)

return "export JACOCO_METADATA_JAR=$(rlocation " + paths.normalize(workspace_prefix + metadata_file.short_path) + ")"

# This is from the Starlark Java builtins in Bazel
def _format_classpath_entry(runfiles_enabled, workspace_prefix, file):
if runfiles_enabled:
return "${RUNPATH}" + file.short_path

return "$(rlocation " + paths.normalize(workspace_prefix + file.short_path) + ")"

def write_launcher(
ctx,
prefix,
Expand All @@ -40,40 +54,48 @@ def write_launcher(
jvm_flags,
extra = "",
jacoco_classpath = None):
"""Macro that writes out a launcher script shell script.
"""Macro that writes out a launcher script shell script. Some of this is from Bazel's Starlark Java builtins.
Args:
runtime_classpath: File containing the classpath required to launch this java target.
main_class: the main class to launch.
jvm_flags: The flags that should be passed to the jvm.
args: Args that should be passed to the Binary.
"""
workspace_name = ctx.workspace_name
workspace_prefix = workspace_name + ("/" if workspace_name else "")

classpath_args = ctx.actions.args()
classpath_args.add_joined(runtime_classpath, format_each = "${RUNPATH}%s", join_with = ":", map_each = _short_path)
classpath_args.set_param_file_format("multiline")
classpath_file = ctx.actions.declare_file("{}classpath.params".format(prefix))
ctx.actions.write(classpath_file, classpath_args)

classpath = "\"$(eval echo \"$(cat ${{RUNPATH}}{})\")\"".format(classpath_file.short_path)

jvm_flags = " ".join(jvm_flags)
template = ctx.file._java_stub_template
# TODO: can we get this info?
# runfiles_enabled = ctx.configuration.runfiles_enabled()
runfiles_enabled = False

java_executable = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].java_executable_runfiles_path
java_path = str(java_executable)
if paths.is_absolute(java_path):
javabin = java_path
java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo]
java_executable = java_runtime_info.java_executable_runfiles_path
if not paths.is_absolute(java_executable):
java_executable = workspace_name + "/" + java_executable
java_executable = paths.normalize(java_executable)

if runfiles_enabled:
prefix = "" if paths.is_absolute(java_executable) else "${JAVA_RUNFILES}/"
javabin = "JAVABIN=${JAVABIN:-" + prefix + java_executable + "}"
else:
javabin = "$JAVA_RUNFILES/{}/{}".format(ctx.workspace_name, java_executable)
javabin = "JAVABIN=${JAVABIN:-$(rlocation " + java_executable + ")}"

template_dict = ctx.actions.template_dict()
template_dict.add_joined(
"%classpath%",
runtime_classpath,
map_each = lambda file: _format_classpath_entry(runfiles_enabled, workspace_prefix, file),
join_with = ctx.configuration.host_path_separator,
format_joined = "\"%s\"",
allow_closure = True,
)

base_substitutions = {
"%classpath%": classpath,
"%javabin%": "JAVABIN=\"{}\"\n{}".format(javabin, extra),
"%jvm_flags%": jvm_flags,
"%needs_runfiles%": "1" if runfiles_enabled else "",
"%runfiles_manifest_only%": "1" if runfiles_enabled else "",
"%workspace_prefix%": ctx.workspace_name + "/",
"%workspace_prefix%": workspace_prefix,
"%javabin%": "{}\n{}".format(javabin, extra),
"%needs_runfiles%": "0" if paths.is_absolute(java_runtime_info.java_executable_exec_path) else "1",
"%jvm_flags%": " ".join(jvm_flags),
"%test_runtime_classpath_file%": "",
}

Expand All @@ -86,9 +108,14 @@ def write_launcher(
for jar in jacoco_classpath
]))
more_outputs = [metadata_file]

template_dict.add(
"%set_jacoco_metadata%",
_format_jacoco_metadata_file(runfiles_enabled, workspace_prefix, metadata_file),
)

more_substitutions = {
"%java_start_class%": "com.google.testing.coverage.JacocoCoverageRunner",
"%set_jacoco_metadata%": "export JACOCO_METADATA_JAR=\"$JAVA_RUNFILES/{}/{}\"".format(ctx.workspace_name, metadata_file.short_path),
"%set_jacoco_main_class%": """export JACOCO_MAIN_CLASS={}""".format(main_class),
"%set_jacoco_java_runfiles_root%": """export JACOCO_JAVA_RUNFILES_ROOT=$JAVA_RUNFILES/{}/""".format(ctx.workspace_name),
"%set_java_coverage_new_implementation%": """export JAVA_COVERAGE_NEW_IMPLEMENTATION=YES""",
Expand All @@ -104,13 +131,14 @@ def write_launcher(
}

ctx.actions.expand_template(
template = template,
template = ctx.file._java_stub_template,
output = output,
substitutions = dicts.add(base_substitutions, more_substitutions),
computed_substitutions = template_dict,
is_executable = True,
)

return more_outputs + [classpath_file]
return more_outputs

def safe_name(value):
return "".join([value[i] if value[i].isalnum() or value[i] == "." else "_" for i in range(len(value))])
Expand Down
2 changes: 1 addition & 1 deletion rules/scala_proto/private/ScalaProtoWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ object ScalaProtoWorker extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

protected def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
protected def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = ScalaProtoRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
2 changes: 1 addition & 1 deletion rules/scalafmt/scalafmt/ScalafmtRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object ScalafmtRunner extends WorkerMain[Unit] {

protected[this] def init(args: Option[Array[String]]): Unit = {}

protected[this] def work(worker: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
protected[this] def work(worker: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = ScalafmtRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream

protected[this] def init(args: Option[Array[String]]): S

protected[this] def work(ctx: S, args: Array[String], out: PrintStream, workDir: Path): Unit
protected[this] def work(ctx: S, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're getting the arguments and verbosity directly from the work request, should we just pass that to the worker directly instead of copying so many of its fields?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should because this does create a nice interface all the workers must adhere to. It does sometimes mean updates when the protocol changes. Good news is that tends to happen rather infrequently.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I can see how this forces worker implementers to think about certain things. Nevermind, then.


protected[this] var isWorker = false

Expand Down Expand Up @@ -101,6 +101,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
} else {
val args = request.getArgumentsList.toArray(Array.empty[String])
val sandboxDir = Path.of(request.getSandboxDir())
val verbosity = request.getVerbosity()
System.err.println(s"WorkRequest $requestId received with args: ${request.getArgumentsList}")

// We go through this hullabaloo with output streams being defined out here, so we can
Expand All @@ -114,7 +115,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
outStream = new ByteArrayOutputStream
out = new PrintStream(outStream)
try {
work(ctx, args, out, sandboxDir)
work(ctx, args, out, sandboxDir, verbosity)
0
} catch {
case e @ AnnexWorkerError(code, _, _) =>
Expand Down Expand Up @@ -210,6 +211,7 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream
args.toArray,
out,
workDir = Path.of(""),
verbosity = 0,
)
} catch {
// This error means the work function encountered an error that we want to not be caught
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ import java.nio.file.Path

object BloopRunner extends WorkerMain[Unit] {
override def init(args: Option[Array[String]]): Unit = ()
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = Bloop
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = Bloop
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ object DepsRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = DepsRunnerRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ object JacocoInstrumenter extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = JacocoRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))

val jacoco = new Instrumenter(new OfflineInstrumentationAccessGenerator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ object ZincRunner extends WorkerMain[ZincRunnerWorkerConfig] {
args: Array[String],
out: PrintStream,
workDir: Path,
verbosity: Int,
): Unit = {
val workRequest = CommonArguments(ArgsUtil.parseArgsOrFailSafe(args, parser, out), workDir)
InterruptUtil.throwIfInterrupted()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object DocRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
val workRequest = DocRequest(workDir, ArgsUtil.parseArgsOrFailSafe(args, argParser, out))
InterruptUtil.throwIfInterrupted()

Expand Down
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ default_java_toolchain(
name = "repository_default_toolchain_21",
configuration = DEFAULT_TOOLCHAIN_CONFIGURATION,
java_runtime = "@rules_java//toolchains:remotejdk_21",
javac_supports_worker_multiplex_sandboxing = True,
source_version = "21",
target_version = "21",
)
2 changes: 1 addition & 1 deletion tests/cancellation/RunnerForCancelSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerForCancelSpec(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
var interrupted = false
var i = 0

Expand Down
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsError(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new Error()
}
}
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsException.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsException(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new Exception()
}
}
2 changes: 1 addition & 1 deletion tests/worker-error/RunnerThatThrowsFatalError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunnerThatThrowsFatalError(stdin: InputStream, stdout: PrintStream)

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
throw new OutOfMemoryError()
}
}
20 changes: 20 additions & 0 deletions tests/worker-verbosity/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@rules_scala_annex//rules:scala.bzl", "scala_binary")
load("verbosity_spec_worker_run.bzl", "verbosity_spec_worker_run")

scala_binary(
name = "verbosity-spec-worker",
srcs = [
"RunnerThatPrintsVerbosity.scala",
],
scala = "//scala:2_13",
tags = ["manual"],
deps = [
"@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sandbox",
"@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/worker",
],
)

verbosity_spec_worker_run(
name = "verbosity-spec-target",
verbosity_spec_worker = ":verbosity-spec-worker",
)
15 changes: 15 additions & 0 deletions tests/worker-verbosity/RunnerThatPrintsVerbosity.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package anx.cancellation

import higherkindness.rules_scala.common.worker.WorkerMain
import higherkindness.rules_scala.common.sandbox.SandboxUtil

import java.io.{InputStream, PrintStream}
import java.nio.file.{Files, Path, Paths}

object RunnerThatPrintsVerbosity extends WorkerMain[Unit] {
override def init(args: Option[Array[String]]): Unit = ()
override def work(ctx: Unit, args: Array[String], out: PrintStream, workDir: Path, verbosity: Int): Unit = {
out.println(s"Verbosity: ${verbosity}")
Files.createFile(SandboxUtil.getSandboxPath(workDir, Paths.get(args(0))))
}
}
12 changes: 12 additions & 0 deletions tests/worker-verbosity/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash -e
. "$(dirname "$0")"/../common.sh

# We use modify_execution_info, nouse_action_cache, and bazel shutdown here
# in order to prevent the disk cache, skyframe cache, and persistent action cache
# from being used for the verbosity spec worker actions and thus getting the
# verbosity we want getting printed. The alternative is to bazel clean, which
# takes much longer.
bazel shutdown
bazel build --modify_execution_info="VerbositySpecWorkerRun=+no-cache" --nouse_action_cache :verbosity-spec-target |& grep -q "Verbosity: 0"
bazel shutdown
bazel build --modify_execution_info="VerbositySpecWorkerRun=+no-cache" --nouse_action_cache --worker_verbose :verbosity-spec-target |& grep -q "Verbosity: 10"
39 changes: 39 additions & 0 deletions tests/worker-verbosity/verbosity_spec_worker_run.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
def _impl(ctx):
foo_file = ctx.actions.declare_file("foo.txt")
outputs = [foo_file]

args = ctx.actions.args()
args.add(foo_file)
args.set_param_file_format("multiline")
args.use_param_file("@%s", use_always = True)

ctx.actions.run(
outputs = outputs,
arguments = [args],
mnemonic = "VerbositySpecWorkerRun",
execution_requirements = {
"supports-multiplex-workers": "1",
"supports-workers": "1",
"supports-multiplex-sandboxing": "1",
"supports-worker-cancellation": "1",
},
progress_message = "Running verbosity spec worker %{label}",
executable = ctx.executable.verbosity_spec_worker,
)

return [
DefaultInfo(files = depset(outputs)),
]

verbosity_spec_worker_run = rule(
implementation = _impl,
doc = "Runs a worker that prints the verbosity level it received from the work request",
attrs = {
"verbosity_spec_worker": attr.label(
executable = True,
cfg = "host",
allow_files = True,
default = Label(":verbosity-spec-worker"),
),
},
)
Loading