From 7b7115ad9179d8a92fd436d8afa5ff0af22b4b2c Mon Sep 17 00:00:00 2001 From: Jaden Peterson Date: Tue, 24 Sep 2024 18:12:57 -0400 Subject: [PATCH] Use an outgoing transition to ensure Scala targets' dependencies are built once --- rules/common/private/utils.bzl | 6 +- .../private/phases/phase_binary_launcher.bzl | 7 +- rules/private/phases/phase_javainfo.bzl | 10 ++- rules/private/phases/phase_test_launcher.bzl | 10 ++- rules/private/phases/phase_zinc_compile.bzl | 6 +- rules/register_toolchain.bzl | 20 ++++- rules/scala.bzl | 83 +++++++++++++------ rules/scala/private/repl.bzl | 7 +- rules/scala/workspace.bzl | 12 ++- tests/scala/toolchain/BUILD | 23 +++++ tests/scala/toolchain/Child.scala | 3 + tests/scala/toolchain/Parent.scala | 5 ++ tests/scala/toolchain/test | 18 ++++ 13 files changed, 171 insertions(+), 39 deletions(-) create mode 100644 tests/scala/toolchain/BUILD create mode 100644 tests/scala/toolchain/Child.scala create mode 100644 tests/scala/toolchain/Parent.scala create mode 100755 tests/scala/toolchain/test diff --git a/rules/common/private/utils.bzl b/rules/common/private/utils.bzl index c755d588..7efa8480 100644 --- a/rules/common/private/utils.bzl +++ b/rules/common/private/utils.bzl @@ -65,7 +65,11 @@ def write_launcher( # runfiles_enabled = ctx.configuration.runfiles_enabled() runfiles_enabled = False - java_runtime_info = ctx.attr._target_jdk[java_common.JavaRuntimeInfo] + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the transition is a + # 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of + # elements in this list is unspecified." + java_runtime_info = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo] java_executable = java_runtime_info.java_executable_runfiles_path if not paths.is_absolute(java_executable): java_executable = workspace_name + "/" + java_executable diff --git a/rules/private/phases/phase_binary_launcher.bzl b/rules/private/phases/phase_binary_launcher.bzl index 5d969cea..2c9116f0 100644 --- a/rules/private/phases/phase_binary_launcher.bzl +++ b/rules/private/phases/phase_binary_launcher.bzl @@ -35,7 +35,12 @@ def phase_binary_launcher(ctx, g): files = inputs + files, transitive_files = depset( order = "default", - transitive = [ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars], + + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the + # transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it + # isn't already. The order of elements in this list is unspecified." + transitive = [ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files, g.javainfo.java_info.transitive_runtime_jars], ), collect_default = True, ), diff --git a/rules/private/phases/phase_javainfo.bzl b/rules/private/phases/phase_javainfo.bzl index 6362b50b..f3dce17c 100644 --- a/rules/private/phases/phase_javainfo.bzl +++ b/rules/private/phases/phase_javainfo.bzl @@ -3,6 +3,7 @@ load( "find_java_runtime_toolchain", "find_java_toolchain", ) +load("@rules_java//java/common:java_common.bzl", "java_common") load( "@rules_scala_annex//rules:providers.bzl", _ScalaConfiguration = "ScalaConfiguration", @@ -34,7 +35,12 @@ def phase_javainfo(ctx, g): ctx.actions, jar = ctx.outputs.jar, target_label = ctx.label, - java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain), + + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the + # transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it + # isn't already. The order of elements in this list is unspecified." + java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]), ) source_jar_name = ctx.outputs.jar.basename.replace(".jar", "-src.jar") @@ -47,7 +53,7 @@ def phase_javainfo(ctx, g): ctx.actions, output_source_jar = output_source_jar, sources = ctx.files.srcs, - java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain), + java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]), ) java_info = JavaInfo( diff --git a/rules/private/phases/phase_test_launcher.bzl b/rules/private/phases/phase_test_launcher.bzl index 11854c69..4694572c 100644 --- a/rules/private/phases/phase_test_launcher.bzl +++ b/rules/private/phases/phase_test_launcher.bzl @@ -16,7 +16,11 @@ load( # def phase_test_launcher(ctx, g): - files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store] + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the transition is a + # 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. The order of + # elements in this list is unspecified." + files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list() + [g.compile.zinc_info.analysis_store] coverage_replacements = {} coverage_runner_jars = depset(direct = []) @@ -31,7 +35,7 @@ def phase_test_launcher(ctx, g): coverage_replacements[jar] if jar in coverage_replacements else jar for jar in g.javainfo.java_info.transitive_runtime_jars.to_list() ]) - runner_jars = depset(transitive = [ctx.attr.runner[JavaInfo].transitive_runtime_jars, coverage_runner_jars]) + runner_jars = depset(transitive = [ctx.attr.runner[0][JavaInfo].transitive_runtime_jars, coverage_runner_jars]) all_jars = [test_jars, runner_jars] args = ctx.actions.args() @@ -43,7 +47,7 @@ def phase_test_launcher(ctx, g): args.add_all("--shared_classpath", shared_deps.transitive_runtime_jars, map_each = _test_launcher_short_path) elif ctx.attr.isolation == "process": subprocess_executable = ctx.actions.declare_file("{}/subprocess".format(ctx.label.name)) - subprocess_runner_jars = ctx.attr.subprocess_runner[JavaInfo].transitive_runtime_jars + subprocess_runner_jars = ctx.attr.subprocess_runner[0][JavaInfo].transitive_runtime_jars all_jars.append(subprocess_runner_jars) files += _write_launcher( ctx, diff --git a/rules/private/phases/phase_zinc_compile.bzl b/rules/private/phases/phase_zinc_compile.bzl index 59976f34..58246420 100644 --- a/rules/private/phases/phase_zinc_compile.bzl +++ b/rules/private/phases/phase_zinc_compile.bzl @@ -30,7 +30,11 @@ def phase_zinc_compile(ctx, g): javacopts = [ ctx.expand_location(option, ctx.attr.data) for option in ctx.attr.javacopts + java_common.default_javac_opts( - java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain), + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the transition + # is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it isn't already. + # The order of elements in this list is unspecified." + java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain[0]), ) ] diff --git a/rules/register_toolchain.bzl b/rules/register_toolchain.bzl index 9f8112dd..0b02bb38 100644 --- a/rules/register_toolchain.bzl +++ b/rules/register_toolchain.bzl @@ -14,6 +14,7 @@ load( "phase_zinc_compile", "phase_zinc_depscheck", ) +load("@rules_scala_annex_scala_toolchain//:default.bzl", "default_scala_toolchain_name") def _bootstrap_configuration_impl(ctx): return [ @@ -205,7 +206,7 @@ def _make_register_toolchain(configuration_rule): register_bootstrap_toolchain = _make_register_toolchain(_bootstrap_configuration) register_zinc_toolchain = _make_register_toolchain(_zinc_configuration) -def _scala_toolchain_transition_impl(_, attr): +def _scala_toolchain_incoming_transition_impl(settings, attr): if attr.scala_toolchain_name == "": return {} @@ -213,13 +214,24 @@ def _scala_toolchain_transition_impl(_, attr): "@rules_scala_annex_scala_toolchain//:scala-toolchain": attr.scala_toolchain_name, } -scala_toolchain_transition = transition( - implementation = _scala_toolchain_transition_impl, +scala_toolchain_incoming_transition = transition( + implementation = _scala_toolchain_incoming_transition_impl, + inputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"], + outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"], +) + +def _scala_toolchain_outgoing_transition_impl(_1, _2): + return { + "@rules_scala_annex_scala_toolchain//:scala-toolchain": default_scala_toolchain_name, + } + +scala_toolchain_outgoing_transition = transition( + implementation = _scala_toolchain_outgoing_transition_impl, inputs = [], outputs = ["@rules_scala_annex_scala_toolchain//:scala-toolchain"], ) -scala_toolchain_transition_attributes = { +scala_toolchain_attributes = { "scala_toolchain_name": attr.string( doc = "The name of the Scala toolchain to use for this target (as provided to `register_*_toolchain`)", ), diff --git a/rules/scala.bzl b/rules/scala.bzl index 0fa4a5f6..ea998d53 100644 --- a/rules/scala.bzl +++ b/rules/scala.bzl @@ -41,12 +41,14 @@ load(":jvm.bzl", _labeled_jars = "labeled_jars") load(":providers.bzl", _ScalaRulePhase = "ScalaRulePhase") load( ":register_toolchain.bzl", - _scala_toolchain_transition = "scala_toolchain_transition", - _scala_toolchain_transition_attributes = "scala_toolchain_transition_attributes", + _scala_toolchain_incoming_transition = "scala_toolchain_incoming_transition", + _scala_toolchain_outgoing_transition = "scala_toolchain_outgoing_transition", + _scala_toolchain_attributes = "scala_toolchain_attributes", ) _compile_private_attributes = { "_java_toolchain": attr.label( + cfg = _scala_toolchain_outgoing_transition, default = Label("@bazel_tools//tools/jdk:current_java_toolchain"), ), "_host_javabase": attr.label( @@ -75,6 +77,7 @@ _compile_private_attributes = { _compile_attributes = { "srcs": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The source Scala and Java files (and `-sources.jar` `.srcjar` `-src.jar` files of those).", allow_files = [ ".scala", @@ -86,10 +89,12 @@ _compile_attributes = { flags = ["DIRECT_COMPILE_TIME_INPUT"], ), "data": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The additional runtime files needed by this library.", allow_files = True, ), "deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, aspects = [ _labeled_jars, _coverage_replacements_provider.aspect, @@ -98,14 +103,17 @@ _compile_attributes = { providers = [JavaInfo], ), "deps_used_whitelist": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The JVM library dependencies to always consider used for `scala_deps_used` checks.", providers = [JavaInfo], ), "deps_unused_whitelist": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The JVM library dependencies to always consider unused for `scala_deps_direct` checks.", providers = [JavaInfo], ), "runtime_deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The JVM runtime-only library dependencies.", providers = [JavaInfo], ), @@ -113,6 +121,7 @@ _compile_attributes = { doc = "The Javac options.", ), "plugins": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The Scalac plugins.", providers = [JavaInfo], ), @@ -121,10 +130,12 @@ _compile_attributes = { ), "resources": attr.label_list( allow_files = True, + cfg = _scala_toolchain_outgoing_transition, doc = "The files to include as classpath resources.", ), "resource_jars": attr.label_list( allow_files = [".jar"], + cfg = _scala_toolchain_outgoing_transition, doc = "The JARs to merge into the output JAR.", ), "scalacopts": attr.string_list( @@ -137,6 +148,7 @@ _library_attributes = { aspects = [ _coverage_replacements_provider.aspect, ], + cfg = _scala_toolchain_outgoing_transition, doc = "The JVM libraries to add as dependencies to any libraries dependent on this one.", providers = [JavaInfo], ), @@ -155,6 +167,7 @@ _runtime_attributes = { doc = "The JVM runtime flags.", ), "runtime_deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The JVM runtime-only library dependencies.", providers = [JavaInfo], ), @@ -162,10 +175,12 @@ _runtime_attributes = { _runtime_private_attributes = { "_target_jdk": attr.label( + cfg = _scala_toolchain_outgoing_transition, default = Label("@bazel_tools//tools/jdk:current_java_runtime"), providers = [java_common.JavaRuntimeInfo], ), "_java_stub_template": attr.label( + cfg = _scala_toolchain_outgoing_transition, default = Label("@anx_java_stub_template//file"), allow_single_file = True, ), @@ -235,11 +250,11 @@ def make_scala_library(*extras): _compile_attributes, _compile_private_attributes, _library_attributes, - _scala_toolchain_transition_attributes, + _scala_toolchain_attributes, _extras_attributes(extras), *[extra["attrs"] for extra in extras] ), - cfg = _scala_toolchain_transition, + cfg = _scala_toolchain_incoming_transition, doc = "Compiles a Scala JVM library.", implementation = _scala_library_implementation, outputs = _dicts.add( @@ -263,7 +278,7 @@ def make_scala_binary(*extras): _compile_private_attributes, _runtime_attributes, _runtime_private_attributes, - _scala_toolchain_transition_attributes, + _scala_toolchain_attributes, { "main_class": attr.string( doc = "The main class. If not provided, it will be inferred by its type signature.", @@ -272,7 +287,7 @@ def make_scala_binary(*extras): _extras_attributes(extras), *[extra["attrs"] for extra in extras] ), - cfg = _scala_toolchain_transition, + cfg = _scala_toolchain_incoming_transition, doc = """ Compiles and links a Scala JVM executable. @@ -309,7 +324,7 @@ def make_scala_test(*extras): _compile_private_attributes, _runtime_attributes, _runtime_private_attributes, - _scala_toolchain_transition_attributes, + _scala_toolchain_attributes, _testing_private_attributes, { "isolation": attr.string( @@ -323,6 +338,7 @@ def make_scala_test(*extras): ), "scalacopts": attr.string_list(), "shared_deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "If isolation is \"classloader\", the list of deps to keep loaded between tests", providers = [JavaInfo], ), @@ -336,13 +352,19 @@ def make_scala_test(*extras): "com.novocode.junit.JUnitFramework", ], ), - "runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test"), - "subprocess_runner": attr.label(default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess"), + "runner": attr.label( + cfg = _scala_toolchain_outgoing_transition, + default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test", + ), + "subprocess_runner": attr.label( + cfg = _scala_toolchain_outgoing_transition, + default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess" + ), }, _extras_attributes(extras), *[extra["attrs"] for extra in extras] ), - cfg = _scala_toolchain_transition, + cfg = _scala_toolchain_incoming_transition, doc = """ Compiles and links a collection of Scala tests. @@ -387,13 +409,17 @@ _scala_repl_private_attributes = _dicts.add( scala_repl = rule( attrs = _dicts.add( _scala_repl_private_attributes, - _scala_toolchain_transition_attributes, + _scala_toolchain_attributes, { "data": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, doc = "The additional runtime files needed by this REPL.", allow_files = True, ), - "deps": attr.label_list(providers = [JavaInfo]), + "deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, + providers = [JavaInfo], + ), "jvm_flags": attr.string_list( doc = "The JVM runtime flags.", ), @@ -402,7 +428,7 @@ scala_repl = rule( ), }, ), - cfg = _scala_toolchain_transition, + cfg = _scala_toolchain_incoming_transition, doc = """ Launches a REPL with all given dependencies available. @@ -442,23 +468,32 @@ Use this only for libraries with macros. Otherwise, use `java_import`. scaladoc = rule( attrs = _dicts.add( - _scala_toolchain_transition_attributes, + _scala_toolchain_attributes, _scaladoc_private_attributes, { - "compiler_deps": attr.label_list(providers = [JavaInfo]), - "deps": attr.label_list(providers = [JavaInfo]), - "srcs": attr.label_list(allow_files = [ - ".java", - ".scala", - ".srcjar", - "-sources.jar", - "-src.jar", - ]), + "compiler_deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, + providers = [JavaInfo], + ), + "deps": attr.label_list( + cfg = _scala_toolchain_outgoing_transition, + providers = [JavaInfo], + ), + "srcs": attr.label_list( + allow_files = [ + ".java", + ".scala", + ".srcjar", + "-sources.jar", + "-src.jar", + ], + cfg = _scala_toolchain_outgoing_transition, + ), "scalacopts": attr.string_list(), "title": attr.string(), }, ), - cfg = _scala_toolchain_transition, + cfg = _scala_toolchain_incoming_transition, doc = """ Generates Scaladocs. """, diff --git a/rules/scala/private/repl.bzl b/rules/scala/private/repl.bzl index f03431c5..e571b80c 100644 --- a/rules/scala/private/repl.bzl +++ b/rules/scala/private/repl.bzl @@ -49,7 +49,12 @@ def scala_repl_implementation(ctx): runfiles = ctx.runfiles( collect_default = True, collect_data = True, - files = ctx.attr._target_jdk[java_common.JavaRuntimeInfo].files.to_list(), + + # See https://bazel.build/extending/config#accessing-attributes-with-transitions: + # "When attaching a transition to an outgoing edge (regardless of whether the + # transition is a 1:1 or 1:2+ transition), `ctx.attr` is forced to be a list if it + # isn't already. The order of elements in this list is unspecified." + files = ctx.attr._target_jdk[0][java_common.JavaRuntimeInfo].files.to_list(), transitive_files = files, ), ), diff --git a/rules/scala/workspace.bzl b/rules/scala/workspace.bzl index a9a30550..66afc8a8 100644 --- a/rules/scala/workspace.bzl +++ b/rules/scala/workspace.bzl @@ -83,14 +83,22 @@ def _toolchain_configuration_repository_impl(repository_ctx): repository_ctx.file( "BUILD", """\ +load(":default.bzl", "default_scala_toolchain_name") load("@bazel_skylib//rules:common_settings.bzl", "string_setting") string_setting( name = "scala-toolchain", - build_setting_default = "{}", + build_setting_default = default_scala_toolchain_name, visibility = ["//visibility:public"], ) -""".format(repository_ctx.attr.default_scala_toolchain_name), +""", + ) + + repository_ctx.file( + "default.bzl", + "default_scala_toolchain_name = \"{}\"\n".format( + repository_ctx.attr.default_scala_toolchain_name, + ), ) _toolchain_configuration_repository = repository_rule( diff --git a/tests/scala/toolchain/BUILD b/tests/scala/toolchain/BUILD new file mode 100644 index 00000000..e07eefc9 --- /dev/null +++ b/tests/scala/toolchain/BUILD @@ -0,0 +1,23 @@ +load("@rules_scala_annex//rules:scala.bzl", "scala_library") + +scala_library( + name = "child", + srcs = ["Child.scala"], + tags = ["manual"], +) + +scala_library( + name = "parent1", + srcs = ["Parent.scala"], + scala_toolchain_name = "test_zinc_2_13", + tags = ["manual"], + deps = [":child"], +) + +scala_library( + name = "parent2", + srcs = ["Parent.scala"], + scala_toolchain_name = "test_zinc_3", + tags = ["manual"], + deps = [":child"], +) diff --git a/tests/scala/toolchain/Child.scala b/tests/scala/toolchain/Child.scala new file mode 100644 index 00000000..2ea5779c --- /dev/null +++ b/tests/scala/toolchain/Child.scala @@ -0,0 +1,3 @@ +object Child { + val greeting: String = "Hello, world!" +} diff --git a/tests/scala/toolchain/Parent.scala b/tests/scala/toolchain/Parent.scala new file mode 100644 index 00000000..6791469d --- /dev/null +++ b/tests/scala/toolchain/Parent.scala @@ -0,0 +1,5 @@ +object Parent { + def main(arguments: Array[String]): Unit = { + println(Child.greeting) + } +} diff --git a/tests/scala/toolchain/test b/tests/scala/toolchain/test new file mode 100755 index 00000000..5d38464d --- /dev/null +++ b/tests/scala/toolchain/test @@ -0,0 +1,18 @@ +#!/bin/bash -e +. "$(dirname "$0")"/../../common.sh + + +bazel_cquery_output="$( + bazel cquery 'deps(:parent1 + :parent2, 1)' --output graph --nograph:factored |& + sed -nr 's/^ *"\/\/scala\/toolchain:child \(([0-9a-f]*)\)"$/\1/p' +)" + +if [ "$(echo "$bazel_cquery_output" | wc -l)" -ne 1 ]; then + echo 'Expected //scala/toolchain:child to be built only once' + exit 1 +fi + +if echo "$bazel_cquery_output" | grep '@rules_scala_annex_scala_toolchain//:scala-toolchain'; then + echo 'Expected //scala/toolchain:child to be built with the default configuration' + exit 1 +fi