Skip to content

Commit

Permalink
Use an outgoing transition to ensure Scala targets' dependencies are …
Browse files Browse the repository at this point in the history
…built once
  • Loading branch information
Jaden Peterson committed Nov 14, 2024
1 parent 302339a commit 85dd91b
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 36 deletions.
6 changes: 5 additions & 1 deletion rules/common/private/utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def write_launcher(
# runfiles_enabled = ctx.configuration.runfiles_enabled()
runfiles_enabled = False

java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
java_runtime_info = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo]
java_executable = java_runtime_info.java_executable_runfiles_path
if not paths.is_absolute(java_executable):
java_executable = workspace_name + "/" + java_executable
Expand Down
7 changes: 6 additions & 1 deletion rules/private/phases/phase_binary_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def phase_binary_launcher(ctx, g):
files = inputs + files,
transitive_files = depset(
order = "default",
transitive = [ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
transitive = [ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],
),
collect_default = True,
),
Expand Down
10 changes: 8 additions & 2 deletions rules/private/phases/phase_javainfo.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@rules_java//java/common:java_common.bzl", "java_common")
load("@rules_java//toolchains:toolchain_utils.bzl", "find_java_toolchain")
load(
"@rules_scala_annex//rules:providers.bzl",
Expand Down Expand Up @@ -29,7 +30,12 @@ def phase_javainfo(ctx, g):
ctx.actions,
jar = ctx.outputs.jar,
target_label = ctx.label,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

source_jar_name = ctx.outputs.jar.basename.replace(".jar", "-src.jar")
Expand All @@ -42,7 +48,7 @@ def phase_javainfo(ctx, g):
ctx.actions,
output_source_jar = output_source_jar,
sources = ctx.files.srcs,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

java_info = JavaInfo(
Expand Down
10 changes: 7 additions & 3 deletions rules/private/phases/phase_test_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ load("//rules/common:private/utils.bzl", _collect = "collect", _write_launcher =
#

def phase_test_launcher(ctx, g):
files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]

coverage_replacements = {}
coverage_runner_jars = depset(direct = [])
Expand All @@ -26,7 +30,7 @@ def phase_test_launcher(ctx, g):
coverage_replacements[jar] if jar in coverage_replacements else jar
for jar in g.javainfo.java_info.transitive_runtime_jars.to_list()
])
runner_jars = depset(transitive = [ctx.attr.runner[JavaInfo].transitive_runtime_jars, coverage_runner_jars])
runner_jars = depset(transitive = [ctx.attr.runner[0][JavaInfo].transitive_runtime_jars, coverage_runner_jars])
all_jars = [test_jars, runner_jars]

args = ctx.actions.args()
Expand All @@ -38,7 +42,7 @@ def phase_test_launcher(ctx, g):
args.add_all("--shared_classpath", shared_deps.transitive_runtime_jars, map_each = _test_launcher_short_path)
elif ctx.attr.isolation == "process":
subprocess_executable = ctx.actions.declare_file("{}/subprocess".format(ctx.label.name))
subprocess_runner_jars = ctx.attr.subprocess_runner[JavaInfo].transitive_runtime_jars
subprocess_runner_jars = ctx.attr.subprocess_runner[0][JavaInfo].transitive_runtime_jars
all_jars.append(subprocess_runner_jars)
files += _write_launcher(
ctx,
Expand Down
6 changes: 5 additions & 1 deletion rules/private/phases/phase_zinc_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def phase_zinc_compile(ctx, g):
javacopts = [
ctx.expand_location(option, ctx.attr.data)
for option in ctx.attr.javacopts + java_common.default_javac_opts(
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition
# is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already.
# The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)
]

Expand Down
46 changes: 35 additions & 11 deletions rules/register_toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@rules_scala_annex_scala_toolchain//:default.bzl", "default_scala_toolchain_name")
load(
"//rules:providers.bzl",
"CodeCoverageConfiguration",
Expand All @@ -14,6 +15,9 @@ load(
"phase_zinc_depscheck",
)

original_scala_toolchain_setting = "@rules_scala_annex_scala_toolchain//:original-scala-toolchain"
scala_toolchain_setting = "@rules_scala_annex_scala_toolchain//:scala-toolchain"

def _bootstrap_configuration_impl(ctx):
return [
platform_common.ToolchainInfo(
Expand Down Expand Up @@ -189,7 +193,7 @@ def _make_register_toolchain(configuration_rule):
native.config_setting(
name = "{}-setting".format(name),
flag_values = {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": name,
scala_toolchain_setting: name,
},
)

Expand All @@ -206,21 +210,41 @@ def _make_register_toolchain(configuration_rule):
register_bootstrap_toolchain = _make_register_toolchain(_bootstrap_configuration)
register_zinc_toolchain = _make_register_toolchain(_zinc_configuration)

def _scala_toolchain_transition_impl(_, attr):
if attr.scala_toolchain_name == "":
return {}
def _scala_toolchain_incoming_transition_impl(settings, attr):
# We set `original_scala_toolchain_setting` so we can reset the toolchain to its original value
# in `scala_toolchain_outgoing_transition`. That way, we can ensure every target is built under
# a single toolchain, thus preventing duplicate builds.
#
# This is inspired by what the rules_go folks are doing.
return {} if attr.scala_toolchain_name == "" else {
original_scala_toolchain_setting: settings[scala_toolchain_setting],
scala_toolchain_setting: attr.scala_toolchain_name,
}

scala_toolchain_incoming_transition = transition(
implementation = _scala_toolchain_incoming_transition_impl,
inputs = [scala_toolchain_setting],
outputs = [original_scala_toolchain_setting, scala_toolchain_setting],
)

def _scala_toolchain_outgoing_transition_impl(settings, _):
original_scala_toolchain = settings[original_scala_toolchain_setting]

return {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": attr.scala_toolchain_name,
return {} if original_scala_toolchain == "" else {
# Although `original_scala_toolchain_setting` will be overridden in the incoming transition,
# we set it to "" so that non-Scala targets aren't built under different values of this
# setting. That way, they aren't built multiple times.
original_scala_toolchain_setting: "",
scala_toolchain_setting: original_scala_toolchain,
}

scala_toolchain_transition = transition(
implementation = _scala_toolchain_transition_impl,
inputs = [],
outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
scala_toolchain_outgoing_transition = transition(
implementation = _scala_toolchain_outgoing_transition_impl,
inputs = [original_scala_toolchain_setting],
outputs = [original_scala_toolchain_setting, scala_toolchain_setting],
)

scala_toolchain_transition_attributes = {
scala_toolchain_attributes = {
"scala_toolchain_name": attr.string(
doc = "The name of the Scala toolchain to use for this target (as provided to `register_*_toolchain`)",
),
Expand Down
Loading

0 comments on commit 85dd91b

Please sign in to comment.