From 43bcfb8cf38530402a5ceee16b655f37e38c2880 Mon Sep 17 00:00:00 2001 From: Andreas Stenius Date: Thu, 7 Mar 2024 10:27:23 +0100 Subject: [PATCH] Core: Fix `__defaults__` to also apply for generated file targets. --- .gitignore | 1 + .../engine/internals/build_files_test.py | 51 +++++- src/python/pants/engine/internals/graph.py | 169 +++++++++--------- 3 files changed, 139 insertions(+), 82 deletions(-) diff --git a/.gitignore b/.gitignore index 8e008fc27f3..2905057168b 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,4 @@ GTAGS /.pants.rc /.venv .tool-versions +TAGS diff --git a/src/python/pants/engine/internals/build_files_test.py b/src/python/pants/engine/internals/build_files_test.py index ba39938c68c..32793f975a1 100644 --- a/src/python/pants/engine/internals/build_files_test.py +++ b/src/python/pants/engine/internals/build_files_test.py @@ -40,10 +40,13 @@ from pants.engine.target import ( Dependencies, MultipleSourcesField, + OverridesField, RegisteredTargetTypes, + SingleSourceField, StringField, Tags, Target, + TargetFilesGenerator, ) from pants.engine.unions import UnionMembership from pants.init.bootstrap_scheduler import BootstrapStatus @@ -357,6 +360,23 @@ class MockTgt(Target): core_fields = (MockDepsField, MockMultipleSourcesField, Tags, ResolveField) +class MockSingleSourceField(SingleSourceField): + pass + + +class MockGeneratedTarget(Target): + alias = "generated" + core_fields = (MockDepsField, Tags, MockSingleSourceField, ResolveField) + + +class MockTargetGenerator(TargetFilesGenerator): + alias = "generator" + core_fields = (MockMultipleSourcesField, OverridesField) + generated_target_cls = MockGeneratedTarget + copied_fields = () + moved_fields = (MockDepsField, Tags, ResolveField) + + def test_resolve_address() -> None: rule_runner = RuleRunner( rules=[QueryRule(Address, [AddressInput]), QueryRule(MaybeAddress, [AddressInput])] @@ -407,7 +427,7 @@ def assert_is_expected(address_input: AddressInput, expected: Address) -> None: def target_adaptor_rule_runner() -> RuleRunner: return RuleRunner( rules=[QueryRule(TargetAdaptor, (TargetAdaptorRequest,))], - target_types=[MockTgt], + target_types=[MockTgt, MockGeneratedTarget, MockTargetGenerator], objects={"parametrize": Parametrize}, ) @@ -500,6 +520,35 @@ def test_target_adaptor_defaults_applied(target_adaptor_rule_runner: RuleRunner) assert target_adaptor.kwargs["tags"] == ["24"] +def test_generated_target_defaults(target_adaptor_rule_runner: RuleRunner) -> None: + target_adaptor_rule_runner.write_files( + { + "BUILD": dedent( + """\ + __defaults__({generated: dict(resolve="mock")}, all=dict(tags=["24"])) + generated(name="explicit", tags=["42"], source="e.txt") + generator(name='gen', sources=["g*.txt"]) + """ + ), + "e.txt": "", + "g1.txt": "", + "g2.txt": "", + } + ) + + explicit_target = target_adaptor_rule_runner.get_target(Address("", target_name="explicit")) + assert explicit_target.address.target_name == "explicit" + assert explicit_target.get(ResolveField).value == "mock" + assert explicit_target.get(Tags).value == ("42",) + + implicit_target = target_adaptor_rule_runner.get_target( + Address("", target_name="gen", relative_file_path="g1.txt") + ) + assert str(implicit_target.address) == "//g1.txt:gen" + assert implicit_target.get(ResolveField).value == "mock" + assert implicit_target.get(Tags).value == ("24",) + + def test_inherit_defaults(target_adaptor_rule_runner: RuleRunner) -> None: target_adaptor_rule_runner.write_files( { diff --git a/src/python/pants/engine/internals/graph.py b/src/python/pants/engine/internals/graph.py index 0b475c7dff8..dda87248bfb 100644 --- a/src/python/pants/engine/internals/graph.py +++ b/src/python/pants/engine/internals/graph.py @@ -27,7 +27,8 @@ from pants.engine.environment import ChosenLocalEnvironmentName, EnvironmentName from pants.engine.fs import EMPTY_SNAPSHOT, GlobMatchErrorBehavior, PathGlobs, Paths, Snapshot from pants.engine.internals import native_engine -from pants.engine.internals.mapper import AddressFamilies +from pants.engine.internals.build_files import AddressFamilyDir +from pants.engine.internals.mapper import AddressFamilies, AddressFamily from pants.engine.internals.native_engine import AddressParseException from pants.engine.internals.parametrize import Parametrize, _TargetParametrization from pants.engine.internals.parametrize import ( # noqa: F401 @@ -255,6 +256,91 @@ async def resolve_all_generator_target_requests( ) +async def _parametrized_target_generators_with_templates( + address: Address, + target_adaptor: TargetAdaptor, + target_type: type[TargetGenerator], + generator_fields: dict[str, Any], + union_membership: UnionMembership, +) -> list[tuple[TargetGenerator, dict[str, Any]]]: + # Pre-load field values from defaults for the target type being generated. + if hasattr(target_type, "generated_target_cls"): + family = await Get(AddressFamily, AddressFamilyDir(address.spec_path)) + template_fields = dict(family.defaults.get(target_type.generated_target_cls.alias, {})) + else: + template_fields = {} + + # Split out the `propagated_fields` before construction. + copied_fields = ( + *target_type.copied_fields, + *target_type._find_copied_plugin_fields(union_membership), + ) + moved_fields = ( + *target_type.moved_fields, + *target_type._find_moved_plugin_fields(union_membership), + ) + for field_type in copied_fields: + for alias in (field_type.deprecated_alias, field_type.alias): + if alias is None: + continue + # Any deprecated field use will be checked on the generator target. + field_value = generator_fields.get(alias, None) + if field_value is not None: + template_fields[alias] = field_value + for field_type in moved_fields: + # We must check for deprecated field usage here before passing the value to the generator. + if field_type.deprecated_alias is not None: + field_value = generator_fields.pop(field_type.deprecated_alias, None) + if field_value is not None: + warn_deprecated_field_type(field_type) + template_fields[field_type.deprecated_alias] = field_value + field_value = generator_fields.pop(field_type.alias, None) + if field_value is not None: + template_fields[field_type.alias] = field_value + + # Move parametrize groups over to `template_fields` in order to expand them. + parametrize_group_field_names = [ + name + for name, field in generator_fields.items() + if isinstance(field, Parametrize) and field.is_group + ] + for field_name in parametrize_group_field_names: + template_fields[field_name] = generator_fields.pop(field_name) + + field_type_aliases = target_type._get_field_aliases_to_field_types( + target_type.class_field_types(union_membership) + ).keys() + generator_fields_parametrized = { + name + for name, field in generator_fields.items() + if isinstance(field, Parametrize) and name in field_type_aliases + } + if generator_fields_parametrized: + noun = pluralize(len(generator_fields_parametrized), "field", include_count=False) + generator_fields_parametrized_text = ", ".join( + repr(f) for f in generator_fields_parametrized + ) + raise InvalidFieldException( + f"Only fields which will be moved to generated targets may be parametrized, " + f"so target generator {address} (with type {target_type.alias}) cannot " + f"parametrize the {generator_fields_parametrized_text} {noun}." + ) + return [ + ( + _create_target( + address, + target_type, + target_adaptor, + generator_fields, + union_membership, + name_explicitly_set=target_adaptor.name is not None, + ), + template, + ) + for address, template in Parametrize.expand(address, template_fields) + ] + + async def _target_generator_overrides( target_generator: TargetGenerator, unmatched_build_file_globs: UnmatchedBuildFileGlobs ) -> dict[str, dict[str, Any]]: @@ -296,7 +382,7 @@ async def resolve_generator_target_requests( if not generate_request: return ResolvedTargetGeneratorRequests() generator_fields = dict(target_adaptor.kwargs) - generators = _parametrized_target_generators_with_templates( + generators = await _parametrized_target_generators_with_templates( req.address, target_adaptor, target_type, @@ -428,85 +514,6 @@ def _target_parametrizations( return _TargetParametrization(target, FrozenDict()) -def _parametrized_target_generators_with_templates( - address: Address, - target_adaptor: TargetAdaptor, - target_type: type[TargetGenerator], - generator_fields: dict[str, Any], - union_membership: UnionMembership, -) -> list[tuple[TargetGenerator, dict[str, Any]]]: - # Split out the `propagated_fields` before construction. - template_fields = {} - copied_fields = ( - *target_type.copied_fields, - *target_type._find_copied_plugin_fields(union_membership), - ) - moved_fields = ( - *target_type.moved_fields, - *target_type._find_moved_plugin_fields(union_membership), - ) - for field_type in copied_fields: - for alias in (field_type.deprecated_alias, field_type.alias): - if alias is None: - continue - # Any deprecated field use will be checked on the generator target. - field_value = generator_fields.get(alias, None) - if field_value is not None: - template_fields[alias] = field_value - for field_type in moved_fields: - # We must check for deprecated field usage here before passing the value to the generator. - if field_type.deprecated_alias is not None: - field_value = generator_fields.pop(field_type.deprecated_alias, None) - if field_value is not None: - warn_deprecated_field_type(field_type) - template_fields[field_type.deprecated_alias] = field_value - field_value = generator_fields.pop(field_type.alias, None) - if field_value is not None: - template_fields[field_type.alias] = field_value - - # Move parametrize groups over to `template_fields` in order to expand them. - parametrize_group_field_names = [ - name - for name, field in generator_fields.items() - if isinstance(field, Parametrize) and field.is_group - ] - for field_name in parametrize_group_field_names: - template_fields[field_name] = generator_fields.pop(field_name) - - field_type_aliases = target_type._get_field_aliases_to_field_types( - target_type.class_field_types(union_membership) - ).keys() - generator_fields_parametrized = { - name - for name, field in generator_fields.items() - if isinstance(field, Parametrize) and name in field_type_aliases - } - if generator_fields_parametrized: - noun = pluralize(len(generator_fields_parametrized), "field", include_count=False) - generator_fields_parametrized_text = ", ".join( - repr(f) for f in generator_fields_parametrized - ) - raise InvalidFieldException( - f"Only fields which will be moved to generated targets may be parametrized, " - f"so target generator {address} (with type {target_type.alias}) cannot " - f"parametrize the {generator_fields_parametrized_text} {noun}." - ) - return [ - ( - _create_target( - address, - target_type, - target_adaptor, - generator_fields, - union_membership, - name_explicitly_set=target_adaptor.name is not None, - ), - template, - ) - for address, template in Parametrize.expand(address, template_fields) - ] - - @rule(_masked_types=[EnvironmentName]) async def resolve_target( request: WrappedTargetRequest,