Skip to content

Commit

Permalink
Core: Fix __defaults__ to also apply for generated file targets.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaos committed Mar 7, 2024
1 parent d2af94c commit 43bcfb8
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ GTAGS
/.pants.rc
/.venv
.tool-versions
TAGS
51 changes: 50 additions & 1 deletion src/python/pants/engine/internals/build_files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@
from pants.engine.target import (
Dependencies,
MultipleSourcesField,
OverridesField,
RegisteredTargetTypes,
SingleSourceField,
StringField,
Tags,
Target,
TargetFilesGenerator,
)
from pants.engine.unions import UnionMembership
from pants.init.bootstrap_scheduler import BootstrapStatus
Expand Down Expand Up @@ -357,6 +360,23 @@ class MockTgt(Target):
core_fields = (MockDepsField, MockMultipleSourcesField, Tags, ResolveField)


class MockSingleSourceField(SingleSourceField):
pass


class MockGeneratedTarget(Target):
alias = "generated"
core_fields = (MockDepsField, Tags, MockSingleSourceField, ResolveField)


class MockTargetGenerator(TargetFilesGenerator):
alias = "generator"
core_fields = (MockMultipleSourcesField, OverridesField)
generated_target_cls = MockGeneratedTarget
copied_fields = ()
moved_fields = (MockDepsField, Tags, ResolveField)


def test_resolve_address() -> None:
rule_runner = RuleRunner(
rules=[QueryRule(Address, [AddressInput]), QueryRule(MaybeAddress, [AddressInput])]
Expand Down Expand Up @@ -407,7 +427,7 @@ def assert_is_expected(address_input: AddressInput, expected: Address) -> None:
def target_adaptor_rule_runner() -> RuleRunner:
return RuleRunner(
rules=[QueryRule(TargetAdaptor, (TargetAdaptorRequest,))],
target_types=[MockTgt],
target_types=[MockTgt, MockGeneratedTarget, MockTargetGenerator],
objects={"parametrize": Parametrize},
)

Expand Down Expand Up @@ -500,6 +520,35 @@ def test_target_adaptor_defaults_applied(target_adaptor_rule_runner: RuleRunner)
assert target_adaptor.kwargs["tags"] == ["24"]


def test_generated_target_defaults(target_adaptor_rule_runner: RuleRunner) -> None:
target_adaptor_rule_runner.write_files(
{
"BUILD": dedent(
"""\
__defaults__({generated: dict(resolve="mock")}, all=dict(tags=["24"]))
generated(name="explicit", tags=["42"], source="e.txt")
generator(name='gen', sources=["g*.txt"])
"""
),
"e.txt": "",
"g1.txt": "",
"g2.txt": "",
}
)

explicit_target = target_adaptor_rule_runner.get_target(Address("", target_name="explicit"))
assert explicit_target.address.target_name == "explicit"
assert explicit_target.get(ResolveField).value == "mock"
assert explicit_target.get(Tags).value == ("42",)

implicit_target = target_adaptor_rule_runner.get_target(
Address("", target_name="gen", relative_file_path="g1.txt")
)
assert str(implicit_target.address) == "//g1.txt:gen"
assert implicit_target.get(ResolveField).value == "mock"
assert implicit_target.get(Tags).value == ("24",)


def test_inherit_defaults(target_adaptor_rule_runner: RuleRunner) -> None:
target_adaptor_rule_runner.write_files(
{
Expand Down
169 changes: 88 additions & 81 deletions src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from pants.engine.environment import ChosenLocalEnvironmentName, EnvironmentName
from pants.engine.fs import EMPTY_SNAPSHOT, GlobMatchErrorBehavior, PathGlobs, Paths, Snapshot
from pants.engine.internals import native_engine
from pants.engine.internals.mapper import AddressFamilies
from pants.engine.internals.build_files import AddressFamilyDir
from pants.engine.internals.mapper import AddressFamilies, AddressFamily
from pants.engine.internals.native_engine import AddressParseException
from pants.engine.internals.parametrize import Parametrize, _TargetParametrization
from pants.engine.internals.parametrize import ( # noqa: F401
Expand Down Expand Up @@ -255,6 +256,91 @@ async def resolve_all_generator_target_requests(
)


async def _parametrized_target_generators_with_templates(
address: Address,
target_adaptor: TargetAdaptor,
target_type: type[TargetGenerator],
generator_fields: dict[str, Any],
union_membership: UnionMembership,
) -> list[tuple[TargetGenerator, dict[str, Any]]]:
# Pre-load field values from defaults for the target type being generated.
if hasattr(target_type, "generated_target_cls"):
family = await Get(AddressFamily, AddressFamilyDir(address.spec_path))
template_fields = dict(family.defaults.get(target_type.generated_target_cls.alias, {}))
else:
template_fields = {}

# Split out the `propagated_fields` before construction.
copied_fields = (
*target_type.copied_fields,
*target_type._find_copied_plugin_fields(union_membership),
)
moved_fields = (
*target_type.moved_fields,
*target_type._find_moved_plugin_fields(union_membership),
)
for field_type in copied_fields:
for alias in (field_type.deprecated_alias, field_type.alias):
if alias is None:
continue
# Any deprecated field use will be checked on the generator target.
field_value = generator_fields.get(alias, None)
if field_value is not None:
template_fields[alias] = field_value
for field_type in moved_fields:
# We must check for deprecated field usage here before passing the value to the generator.
if field_type.deprecated_alias is not None:
field_value = generator_fields.pop(field_type.deprecated_alias, None)
if field_value is not None:
warn_deprecated_field_type(field_type)
template_fields[field_type.deprecated_alias] = field_value
field_value = generator_fields.pop(field_type.alias, None)
if field_value is not None:
template_fields[field_type.alias] = field_value

# Move parametrize groups over to `template_fields` in order to expand them.
parametrize_group_field_names = [
name
for name, field in generator_fields.items()
if isinstance(field, Parametrize) and field.is_group
]
for field_name in parametrize_group_field_names:
template_fields[field_name] = generator_fields.pop(field_name)

field_type_aliases = target_type._get_field_aliases_to_field_types(
target_type.class_field_types(union_membership)
).keys()
generator_fields_parametrized = {
name
for name, field in generator_fields.items()
if isinstance(field, Parametrize) and name in field_type_aliases
}
if generator_fields_parametrized:
noun = pluralize(len(generator_fields_parametrized), "field", include_count=False)
generator_fields_parametrized_text = ", ".join(
repr(f) for f in generator_fields_parametrized
)
raise InvalidFieldException(
f"Only fields which will be moved to generated targets may be parametrized, "
f"so target generator {address} (with type {target_type.alias}) cannot "
f"parametrize the {generator_fields_parametrized_text} {noun}."
)
return [
(
_create_target(
address,
target_type,
target_adaptor,
generator_fields,
union_membership,
name_explicitly_set=target_adaptor.name is not None,
),
template,
)
for address, template in Parametrize.expand(address, template_fields)
]


async def _target_generator_overrides(
target_generator: TargetGenerator, unmatched_build_file_globs: UnmatchedBuildFileGlobs
) -> dict[str, dict[str, Any]]:
Expand Down Expand Up @@ -296,7 +382,7 @@ async def resolve_generator_target_requests(
if not generate_request:
return ResolvedTargetGeneratorRequests()
generator_fields = dict(target_adaptor.kwargs)
generators = _parametrized_target_generators_with_templates(
generators = await _parametrized_target_generators_with_templates(
req.address,
target_adaptor,
target_type,
Expand Down Expand Up @@ -428,85 +514,6 @@ def _target_parametrizations(
return _TargetParametrization(target, FrozenDict())


def _parametrized_target_generators_with_templates(
address: Address,
target_adaptor: TargetAdaptor,
target_type: type[TargetGenerator],
generator_fields: dict[str, Any],
union_membership: UnionMembership,
) -> list[tuple[TargetGenerator, dict[str, Any]]]:
# Split out the `propagated_fields` before construction.
template_fields = {}
copied_fields = (
*target_type.copied_fields,
*target_type._find_copied_plugin_fields(union_membership),
)
moved_fields = (
*target_type.moved_fields,
*target_type._find_moved_plugin_fields(union_membership),
)
for field_type in copied_fields:
for alias in (field_type.deprecated_alias, field_type.alias):
if alias is None:
continue
# Any deprecated field use will be checked on the generator target.
field_value = generator_fields.get(alias, None)
if field_value is not None:
template_fields[alias] = field_value
for field_type in moved_fields:
# We must check for deprecated field usage here before passing the value to the generator.
if field_type.deprecated_alias is not None:
field_value = generator_fields.pop(field_type.deprecated_alias, None)
if field_value is not None:
warn_deprecated_field_type(field_type)
template_fields[field_type.deprecated_alias] = field_value
field_value = generator_fields.pop(field_type.alias, None)
if field_value is not None:
template_fields[field_type.alias] = field_value

# Move parametrize groups over to `template_fields` in order to expand them.
parametrize_group_field_names = [
name
for name, field in generator_fields.items()
if isinstance(field, Parametrize) and field.is_group
]
for field_name in parametrize_group_field_names:
template_fields[field_name] = generator_fields.pop(field_name)

field_type_aliases = target_type._get_field_aliases_to_field_types(
target_type.class_field_types(union_membership)
).keys()
generator_fields_parametrized = {
name
for name, field in generator_fields.items()
if isinstance(field, Parametrize) and name in field_type_aliases
}
if generator_fields_parametrized:
noun = pluralize(len(generator_fields_parametrized), "field", include_count=False)
generator_fields_parametrized_text = ", ".join(
repr(f) for f in generator_fields_parametrized
)
raise InvalidFieldException(
f"Only fields which will be moved to generated targets may be parametrized, "
f"so target generator {address} (with type {target_type.alias}) cannot "
f"parametrize the {generator_fields_parametrized_text} {noun}."
)
return [
(
_create_target(
address,
target_type,
target_adaptor,
generator_fields,
union_membership,
name_explicitly_set=target_adaptor.name is not None,
),
template,
)
for address, template in Parametrize.expand(address, template_fields)
]


@rule(_masked_types=[EnvironmentName])
async def resolve_target(
request: WrappedTargetRequest,
Expand Down

0 comments on commit 43bcfb8

Please sign in to comment.