Skip to content

Commit

Permalink
Use an outgoing transition to ensure Scala targets' dependencies are …
Browse files Browse the repository at this point in the history
…built once
  • Loading branch information
Jaden Peterson committed Nov 13, 2024
1 parent 05961dd commit 7b7115a
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 39 deletions.
6 changes: 5 additions & 1 deletion rules/common/private/utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def write_launcher(
# runfiles_enabled = ctx.configuration.runfiles_enabled()
runfiles_enabled = False

java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
java_runtime_info = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo]
java_executable = java_runtime_info.java_executable_runfiles_path
if not paths.is_absolute(java_executable):
java_executable = workspace_name + "/" + java_executable
Expand Down
7 changes: 6 additions & 1 deletion rules/private/phases/phase_binary_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def phase_binary_launcher(ctx, g):
files = inputs + files,
transitive_files = depset(
order = "default",
transitive = [ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
transitive = [ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars],
),
collect_default = True,
),
Expand Down
10 changes: 8 additions & 2 deletions rules/private/phases/phase_javainfo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load(
"find_java_runtime_toolchain",
"find_java_toolchain",
)
load("@rules_java//java/common:java_common.bzl", "java_common")
load(
"@rules_scala_annex//rules:providers.bzl",
_ScalaConfiguration = "ScalaConfiguration",
Expand Down Expand Up @@ -34,7 +35,12 @@ def phase_javainfo(ctx, g):
ctx.actions,
jar = ctx.outputs.jar,
target_label = ctx.label,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),

# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the
# transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it
# isn't already. The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

source_jar_name = ctx.outputs.jar.basename.replace(".jar", "-src.jar")
Expand All @@ -47,7 +53,7 @@ def phase_javainfo(ctx, g):
ctx.actions,
output_source_jar = output_source_jar,
sources = ctx.files.srcs,
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)

java_info = JavaInfo(
Expand Down
10 changes: 7 additions & 3 deletions rules/private/phases/phase_test_launcher.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ load(
#

def phase_test_launcher(ctx, g):
files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition is a
# 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of
# elements in this list is unspecified."
files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store]

coverage_replacements = {}
coverage_runner_jars = depset(direct = [])
Expand All @@ -31,7 +35,7 @@ def phase_test_launcher(ctx, g):
coverage_replacements[jar] if jar in coverage_replacements else jar
for jar in g.javainfo.java_info.transitive_runtime_jars.to_list()
])
runner_jars = depset(transitive = [ctx.attr.runner[JavaInfo].transitive_runtime_jars, coverage_runner_jars])
runner_jars = depset(transitive = [ctx.attr.runner[0][JavaInfo].transitive_runtime_jars, coverage_runner_jars])
all_jars = [test_jars, runner_jars]

args = ctx.actions.args()
Expand All @@ -43,7 +47,7 @@ def phase_test_launcher(ctx, g):
args.add_all("--shared_classpath", shared_deps.transitive_runtime_jars, map_each = _test_launcher_short_path)
elif ctx.attr.isolation == "process":
subprocess_executable = ctx.actions.declare_file("{}/subprocess".format(ctx.label.name))
subprocess_runner_jars = ctx.attr.subprocess_runner[JavaInfo].transitive_runtime_jars
subprocess_runner_jars = ctx.attr.subprocess_runner[0][JavaInfo].transitive_runtime_jars
all_jars.append(subprocess_runner_jars)
files += _write_launcher(
ctx,
Expand Down
6 changes: 5 additions & 1 deletion rules/private/phases/phase_zinc_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def phase_zinc_compile(ctx, g):
javacopts = [
ctx.expand_location(option, ctx.attr.data)
for option in ctx.attr.javacopts + java_common.default_javac_opts(
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain),
# See https://bazel.build/extending/config#accessing-attributes-with-transitions:
# "When attaching a transition to an outgoing edge (regardless of whether the transition
# is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already.
# The order of elements in this list is unspecified."
java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]),
)
]

Expand Down
20 changes: 16 additions & 4 deletions rules/register_toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ load(
"phase_zinc_compile",
"phase_zinc_depscheck",
)
load("@rules_scala_annex_scala_toolchain//:default.bzl", "default_scala_toolchain_name")

def _bootstrap_configuration_impl(ctx):
return [
Expand Down Expand Up @@ -205,21 +206,32 @@ def _make_register_toolchain(configuration_rule):
register_bootstrap_toolchain = _make_register_toolchain(_bootstrap_configuration)
register_zinc_toolchain = _make_register_toolchain(_zinc_configuration)

def _scala_toolchain_transition_impl(_, attr):
def _scala_toolchain_incoming_transition_impl(settings, attr):
if attr.scala_toolchain_name == "":
return {}

return {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": attr.scala_toolchain_name,
}

scala_toolchain_transition = transition(
implementation = _scala_toolchain_transition_impl,
scala_toolchain_incoming_transition = transition(
implementation = _scala_toolchain_incoming_transition_impl,
inputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
)

def _scala_toolchain_outgoing_transition_impl(_1, _2):
return {
"@rules_scala_annex_scala_toolchain//:scala-toolchain": default_scala_toolchain_name,
}

scala_toolchain_outgoing_transition = transition(
implementation = _scala_toolchain_outgoing_transition_impl,
inputs = [],
outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"],
)

scala_toolchain_transition_attributes = {
scala_toolchain_attributes = {
"scala_toolchain_name": attr.string(
doc = "The name of the Scala toolchain to use for this target (as provided to `register_*_toolchain`)",
),
Expand Down
83 changes: 59 additions & 24 deletions rules/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ load(":jvm.bzl", _labeled_jars = "labeled_jars")
load(":providers.bzl", _ScalaRulePhase = "ScalaRulePhase")
load(
":register_toolchain.bzl",
_scala_toolchain_transition = "scala_toolchain_transition",
_scala_toolchain_transition_attributes = "scala_toolchain_transition_attributes",
_scala_toolchain_incoming_transition = "scala_toolchain_incoming_transition",
_scala_toolchain_outgoing_transition = "scala_toolchain_outgoing_transition",
_scala_toolchain_attributes = "scala_toolchain_attributes",
)

_compile_private_attributes = {
"_java_toolchain": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@bazel_tools//tools/jdk:current_java_toolchain"),
),
"_host_javabase": attr.label(
Expand Down Expand Up @@ -75,6 +77,7 @@ _compile_private_attributes = {

_compile_attributes = {
"srcs": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The source Scala and Java files (and `-sources.jar` `.srcjar` `-src.jar` files of those).",
allow_files = [
".scala",
Expand All @@ -86,10 +89,12 @@ _compile_attributes = {
flags = ["DIRECT_COMPILE_TIME_INPUT"],
),
"data": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The additional runtime files needed by this library.",
allow_files = True,
),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
aspects = [
_labeled_jars,
_coverage_replacements_provider.aspect,
Expand All @@ -98,21 +103,25 @@ _compile_attributes = {
providers = [JavaInfo],
),
"deps_used_whitelist": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM library dependencies to always consider used for `scala_deps_used` checks.",
providers = [JavaInfo],
),
"deps_unused_whitelist": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM library dependencies to always consider unused for `scala_deps_direct` checks.",
providers = [JavaInfo],
),
"runtime_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM runtime-only library dependencies.",
providers = [JavaInfo],
),
"javacopts": attr.string_list(
doc = "The Javac options.",
),
"plugins": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The Scalac plugins.",
providers = [JavaInfo],
),
Expand All @@ -121,10 +130,12 @@ _compile_attributes = {
),
"resources": attr.label_list(
allow_files = True,
cfg = _scala_toolchain_outgoing_transition,
doc = "The files to include as classpath resources.",
),
"resource_jars": attr.label_list(
allow_files = [".jar"],
cfg = _scala_toolchain_outgoing_transition,
doc = "The JARs to merge into the output JAR.",
),
"scalacopts": attr.string_list(
Expand All @@ -137,6 +148,7 @@ _library_attributes = {
aspects = [
_coverage_replacements_provider.aspect,
],
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM libraries to add as dependencies to any libraries dependent on this one.",
providers = [JavaInfo],
),
Expand All @@ -155,17 +167,20 @@ _runtime_attributes = {
doc = "The JVM runtime flags.",
),
"runtime_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The JVM runtime-only library dependencies.",
providers = [JavaInfo],
),
}

_runtime_private_attributes = {
"_target_jdk": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@bazel_tools//tools/jdk:current_java_runtime"),
providers = [java_common.JavaRuntimeInfo],
),
"_java_stub_template": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = Label("@anx_java_stub_template//file"),
allow_single_file = True,
),
Expand Down Expand Up @@ -235,11 +250,11 @@ def make_scala_library(*extras):
_compile_attributes,
_compile_private_attributes,
_library_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = "Compiles a Scala JVM library.",
implementation = _scala_library_implementation,
outputs = _dicts.add(
Expand All @@ -263,7 +278,7 @@ def make_scala_binary(*extras):
_compile_private_attributes,
_runtime_attributes,
_runtime_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
{
"main_class": attr.string(
doc = "The main class. If not provided, it will be inferred by its type signature.",
Expand All @@ -272,7 +287,7 @@ def make_scala_binary(*extras):
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Compiles and links a Scala JVM executable.
Expand Down Expand Up @@ -309,7 +324,7 @@ def make_scala_test(*extras):
_compile_private_attributes,
_runtime_attributes,
_runtime_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_testing_private_attributes,
{
"isolation": attr.string(
Expand All @@ -323,6 +338,7 @@ def make_scala_test(*extras):
),
"scalacopts": attr.string_list(),
"shared_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "If isolation is \"classloader\", the list of deps to keep loaded between tests",
providers = [JavaInfo],
),
Expand All @@ -336,13 +352,19 @@ def make_scala_test(*extras):
"com.novocode.junit.JUnitFramework",
],
),
"runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test"),
"subprocess_runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess"),
"runner": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test",
),
"subprocess_runner": attr.label(
cfg = _scala_toolchain_outgoing_transition,
default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess"
),
},
_extras_attributes(extras),
*[extra["attrs"] for extra in extras]
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Compiles and links a collection of Scala tests.
Expand Down Expand Up @@ -387,13 +409,17 @@ _scala_repl_private_attributes = _dicts.add(
scala_repl = rule(
attrs = _dicts.add(
_scala_repl_private_attributes,
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
{
"data": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
doc = "The additional runtime files needed by this REPL.",
allow_files = True,
),
"deps": attr.label_list(providers = [JavaInfo]),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
providers = [JavaInfo],
),
"jvm_flags": attr.string_list(
doc = "The JVM runtime flags.",
),
Expand All @@ -402,7 +428,7 @@ scala_repl = rule(
),
},
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Launches a REPL with all given dependencies available.
Expand Down Expand Up @@ -442,23 +468,32 @@ Use this only for libraries with macros. Otherwise, use `java_import`.

scaladoc = rule(
attrs = _dicts.add(
_scala_toolchain_transition_attributes,
_scala_toolchain_attributes,
_scaladoc_private_attributes,
{
"compiler_deps": attr.label_list(providers = [JavaInfo]),
"deps": attr.label_list(providers = [JavaInfo]),
"srcs": attr.label_list(allow_files = [
".java",
".scala",
".srcjar",
"-sources.jar",
"-src.jar",
]),
"compiler_deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
providers = [JavaInfo],
),
"deps": attr.label_list(
cfg = _scala_toolchain_outgoing_transition,
providers = [JavaInfo],
),
"srcs": attr.label_list(
allow_files = [
".java",
".scala",
".srcjar",
"-sources.jar",
"-src.jar",
],
cfg = _scala_toolchain_outgoing_transition,
),
"scalacopts": attr.string_list(),
"title": attr.string(),
},
),
cfg = _scala_toolchain_transition,
cfg = _scala_toolchain_incoming_transition,
doc = """
Generates Scaladocs.
""",
Expand Down
Loading

0 comments on commit 7b7115a

Please sign in to comment.