Skip to content

Commit

Permalink
Use an outgoing transition to ensure Scala targets' dependencies are …
Browse files Browse the repository at this point in the history
…built once
  • Loading branch information
Jaden Peterson committed Nov 13, 2024
1 parent 561a75d commit 7bab259
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 29 deletions.
6 changes: 5 additions & 1 deletion rules/common/private/utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def write_launcher(
# runfiles_enabled = ctx.configuration.runfiles_enabled()
runfiles_enabled = False

java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
java_runtime_info = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo]
java_executable = java_runtime_info.java_executable_runfiles_path
if not paths.is_absolute(java_executable):
java_executable = workspace_name + "/" + java_executable
Expand Down
7 changes: 6 additions & 1 deletion rules/private/phases/phase_binary_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def phase_binary_launcher(ctx, g):
files = inputs + files,
transitive_files = depset(
order = "default",
transitive = [ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
transitive = [ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],
),
collect_default = True,
),
Expand Down
10 changes: 8 additions & 2 deletions rules/private/phases/phase_javainfo.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@rules_java//java/common:java_common.bzl", "java_common")
load("@rules_java//toolchains:toolchain_utils.bzl", "find_java_toolchain")
load(
"@rules_scala_annex//rules:providers.bzl",
Expand Down Expand Up @@ -29,7 +30,12 @@ def phase_javainfo(ctx, g):
ctx.actions,
jar = ctx.outputs.jar,
target_label = ctx.label,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

source_jar_name = ctx.outputs.jar.basename.replace(".jar", "-src.jar")
Expand All @@ -42,7 +48,7 @@ def phase_javainfo(ctx, g):
ctx.actions,
output_source_jar = output_source_jar,
sources = ctx.files.srcs,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

java_info = JavaInfo(
Expand Down
10 changes: 7 additions & 3 deletions rules/private/phases/phase_test_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ load("//rules/common:private/utils.bzl", _collect = "collect", _write_launcher =
#

def phase_test_launcher(ctx, g):
files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]

coverage_replacements = {}
coverage_runner_jars = depset(direct = [])
Expand All @@ -26,7 +30,7 @@ def phase_test_launcher(ctx, g):
coverage_replacements[jar] if jar in coverage_replacements else jar
for jar in g.javainfo.java_info.transitive_runtime_jars.to_list()
])
runner_jars = depset(transitive = [ctx.attr.runner[JavaInfo].transitive_runtime_jars, coverage_runner_jars])
runner_jars = depset(transitive = [ctx.attr.runner[0][JavaInfo].transitive_runtime_jars, coverage_runner_jars])
all_jars = [test_jars, runner_jars]

args = ctx.actions.args()
Expand All @@ -38,7 +42,7 @@ def phase_test_launcher(ctx, g):
args.add_all("--shared_classpath", shared_deps.transitive_runtime_jars, map_each = _test_launcher_short_path)
elif ctx.attr.isolation == "process":
subprocess_executable = ctx.actions.declare_file("{}/subprocess".format(ctx.label.name))
subprocess_runner_jars = ctx.attr.subprocess_runner[JavaInfo].transitive_runtime_jars
subprocess_runner_jars = ctx.attr.subprocess_runner[0][JavaInfo].transitive_runtime_jars
all_jars.append(subprocess_runner_jars)
files += _write_launcher(
ctx,
Expand Down
6 changes: 5 additions & 1 deletion rules/private/phases/phase_zinc_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def phase_zinc_compile(ctx, g):
javacopts = [
ctx.expand_location(option, ctx.attr.data)
for option in ctx.attr.javacopts + java_common.default_javac_opts(
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition
# is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already.
# The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)
]

Expand Down
20 changes: 16 additions & 4 deletions rules/register_toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@rules_scala_annex_scala_toolchain//:default.bzl", "default_scala_toolchain_name")
load(
"//rules:providers.bzl",
"CodeCoverageConfiguration",
Expand Down Expand Up @@ -206,21 +207,32 @@ def _make_register_toolchain(configuration_rule):
register_bootstrap_toolchain = _make_register_toolchain(_bootstrap_configuration)
register_zinc_toolchain = _make_register_toolchain(_zinc_configuration)

def _scala_toolchain_transition_impl(_, attr):
def _scala_toolchain_incoming_transition_impl(settings, attr):
if attr.scala_toolchain_name == "":
return {}

return {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": attr.scala_toolchain_name,
}

scala_toolchain_transition = transition(
implementation = _scala_toolchain_transition_impl,
scala_toolchain_incoming_transition = transition(
implementation = _scala_toolchain_incoming_transition_impl,
inputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
)

def _scala_toolchain_outgoing_transition_impl(_1, _2):
return {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": default_scala_toolchain_name,
}

scala_toolchain_outgoing_transition = transition(
implementation = _scala_toolchain_outgoing_transition_impl,
inputs = [],
outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
)

scala_toolchain_transition_attributes = {
scala_toolchain_attributes = {
"scala_toolchain_name": attr.string(
doc = "The name of the Scala toolchain to use for this target (as provided to `register_*_toolchain`)",
),
Expand Down
55 changes: 41 additions & 14 deletions rules/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ load(":jvm.bzl", _labeled_jars = "labeled_jars")
load(":providers.bzl", _ScalaRulePhase = "ScalaRulePhase")
load(
":register_toolchain.bzl",
_scala_toolchain_transition = "scala_toolchain_transition",
_scala_toolchain_transition_attributes = "scala_toolchain_transition_attributes",
_scala_toolchain_attributes = "scala_toolchain_attributes",
_scala_toolchain_incoming_transition = "scala_toolchain_incoming_transition",
_scala_toolchain_outgoing_transition = "scala_toolchain_outgoing_transition",
)

_compile_private_attributes = {
"_java_toolchain": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@bazel_tools//tools/jdk:current_java_toolchain"),
),
"_host_javabase": attr.label(
Expand Down Expand Up @@ -77,6 +79,7 @@ _compile_private_attributes = {

_compile_attributes = {
"srcs": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The source Scala and Java files (and `-sources.jar` `.srcjar` `-src.jar` files of those).",
allow_files = [
".scala",
Expand All @@ -88,10 +91,12 @@ _compile_attributes = {
flags = ["DIRECT_COMPILE_TIME_INPUT"],
),
"data": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The additional runtime files needed by this library.",
allow_files = True,
),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
aspects = [
_labeled_jars,
_coverage_replacements_provider.aspect,
Expand All @@ -100,21 +105,25 @@ _compile_attributes = {
providers = [JavaInfo],
),
"deps_used_whitelist": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM library dependencies to always consider used for `scala_deps_used` checks.",
providers = [JavaInfo],
),
"deps_unused_whitelist": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM library dependencies to always consider unused for `scala_deps_direct` checks.",
providers = [JavaInfo],
),
"runtime_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM runtime-only library dependencies.",
providers = [JavaInfo],
),
"javacopts": attr.string_list(
doc = "The Javac options.",
),
"plugins": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The Scalac plugins.",
providers = [JavaInfo],
),
Expand All @@ -123,10 +132,12 @@ _compile_attributes = {
),
"resources": attr.label_list(
allow_files = True,
cfg = _scala_toolchain_outgoing_transition,
doc = "The files to include as classpath resources.",
),
"resource_jars": attr.label_list(
allow_files = [".jar"],
cfg = _scala_toolchain_outgoing_transition,
doc = "The JARs to merge into the output JAR.",
),
"scalacopts": attr.string_list(
Expand All @@ -139,6 +150,7 @@ _library_attributes = {
aspects = [
_coverage_replacements_provider.aspect,
],
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM libraries to add as dependencies to any libraries dependent on this one.",
providers = [JavaInfo],
),
Expand All @@ -157,17 +169,20 @@ _runtime_attributes = {
doc = "The JVM runtime flags.",
),
"runtime_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM runtime-only library dependencies.",
providers = [JavaInfo],
),
}

_runtime_private_attributes = {
"_target_jdk": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@bazel_tools//tools/jdk:current_java_runtime"),
providers = [java_common.JavaRuntimeInfo],
),
"_java_stub_template": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@anx_java_stub_template//file"),
allow_single_file = True,
),
Expand Down Expand Up @@ -243,11 +258,11 @@ def make_scala_library(*extras):
_compile_attributes,
_compile_private_attributes,
_library_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = "Compiles a Scala JVM library.",
implementation = _scala_library_implementation,
outputs = _dicts.add(
Expand All @@ -271,7 +286,7 @@ def make_scala_binary(*extras):
_compile_private_attributes,
_runtime_attributes,
_runtime_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
{
"main_class": attr.string(
doc = "The main class. If not provided, it will be inferred by its type signature.",
Expand All @@ -280,7 +295,7 @@ def make_scala_binary(*extras):
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Compiles and links a Scala JVM executable.
Expand Down Expand Up @@ -317,7 +332,7 @@ def make_scala_test(*extras):
_compile_private_attributes,
_runtime_attributes,
_runtime_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_testing_private_attributes,
{
"isolation": attr.string(
Expand All @@ -331,6 +346,7 @@ def make_scala_test(*extras):
),
"scalacopts": attr.string_list(doc = "Options to pass to scalac."),
"shared_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "If isolation is \"classloader\", the list of deps to keep loaded between tests",
providers = [JavaInfo],
),
Expand All @@ -345,13 +361,19 @@ def make_scala_test(*extras):
],
doc = "The list of test frameworks to check for. These should conform to the sbt test interface (https://github.com/sbt/test-interface).",
),
"runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test"),
"subprocess_runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess"),
"runner": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test",
),
"subprocess_runner": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess",
),
},
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Compiles and links a collection of Scala tests.
Expand Down Expand Up @@ -396,13 +418,15 @@ _scala_repl_private_attributes = _dicts.add(
scala_repl = rule(
attrs = _dicts.add(
_scala_repl_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
{
"data": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The additional runtime files needed by this REPL.",
allow_files = True,
),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "Dependencies that should be made available to the REPL.",
providers = [JavaInfo],
),
Expand All @@ -412,7 +436,7 @@ scala_repl = rule(
"scalacopts": attr.string_list(doc = "Options to pass to scalac."),
},
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Launches a REPL with all given dependencies available.
Expand Down Expand Up @@ -469,14 +493,16 @@ Use this only for libraries with macros. Otherwise, use `java_import`.""",

scaladoc = rule(
attrs = _dicts.add(
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_scaladoc_private_attributes,
{
"compiler_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "JVM targets that should be included on the compile classpath.",
providers = [JavaInfo],
),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "Dependencies that should be made available to the Scaladoc tool. These may include libraries referenced in Scaladoc or public signatures.",
providers = [JavaInfo],
),
Expand All @@ -488,13 +514,14 @@ scaladoc = rule(
"-sources.jar",
"-src.jar",
],
cfg = _scala_toolchain_outgoing_transition,
doc = "Sources from which to generate Scaladoc. These may include `*.java` files, `*.scala` files, and source JARs.",
),
"scalacopts": attr.string_list(doc = "Options to pass to scalac."),
"title": attr.string(doc = "The name of the project. If none is provided, the target label will be used."),
},
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = "Generates Scaladoc.",
implementation = _scaladoc_implementation,
toolchains = [
Expand Down
7 changes: 6 additions & 1 deletion rules/scala/private/repl.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def scala_repl_implementation(ctx):
runfiles = ctx.runfiles(
collect_default = True,
collect_data = True,
files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list(),

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list(),
transitive_files = files,
),
),
Expand Down
Loading

0 comments on commit 7bab259

Please sign in to comment.