Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate jvm_artifact targets from pom.xml #20336

Merged
8 changes: 7 additions & 1 deletion src/python/pants/jvm/jvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from pants.jvm.resolve import coursier_fetch, jvm_tool
from pants.jvm.shading.rules import rules as shading_rules
from pants.jvm.strip_jar import strip_jar
from pants.jvm.target_types import DeployJarTarget, JvmArtifactTarget, JvmWarTarget
from pants.jvm.target_types import (
DeployJarTarget,
JvmArtifactsTargetGenerator,
JvmArtifactTarget,
JvmWarTarget,
)
from pants.jvm.target_types import build_file_aliases as jvm_build_file_aliases
from pants.jvm.test import junit

Expand All @@ -19,6 +24,7 @@ def target_types():
return [
DeployJarTarget,
JvmArtifactTarget,
JvmArtifactsTargetGenerator,
JvmWarTarget,
]

Expand Down
143 changes: 141 additions & 2 deletions src/python/pants/jvm/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,32 @@
from __future__ import annotations

import dataclasses
import re
import xml.etree.ElementTree as ET
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Callable, ClassVar, Iterable, Optional, Tuple, Type, Union
from typing import Callable, ClassVar, Iterable, Iterator, Optional, Tuple, Type, Union

from pants.build_graph.build_file_aliases import BuildFileAliases
from pants.core.goals.generate_lockfiles import UnrecognizedResolveNamesError
from pants.core.goals.package import OutputPathField
from pants.core.goals.run import RestartableField, RunFieldSet, RunInSandboxBehavior, RunRequest
from pants.core.goals.test import TestExtraEnvVarsField, TestTimeoutField
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.addresses import Address
from pants.engine.fs import Digest, DigestContents
from pants.engine.internals.selectors import Get
from pants.engine.rules import Rule, collect_rules, rule
from pants.engine.target import (
COMMON_TARGET_FIELDS,
AsyncFieldMixin,
BoolField,
Dependencies,
DictStringToStringSequenceField,
FieldDefaultFactoryRequest,
FieldDefaultFactoryResult,
GeneratedTargets,
GenerateTargetsRequest,
InvalidFieldException,
InvalidTargetException,
OptionalSingleSourceField,
Expand All @@ -32,10 +39,13 @@
StringField,
StringSequenceField,
Target,
TargetGenerator,
)
from pants.engine.unions import UnionRule
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.resolve.coordinate import Coordinate
from pants.jvm.subsystems import JvmSubsystem
from pants.util.docutil import git_url
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
from pants.util.memo import memoized
from pants.util.strutil import bullet_list, help_text, pluralize, softwrap
Expand Down Expand Up @@ -426,6 +436,134 @@ def validate(self) -> None:
)


# -----------------------------------------------------------------------------------------------
# Generate `jvm_artifact` targets from pom.xml
# -----------------------------------------------------------------------------------------------


class PomXmlSourceField(SingleSourceField):
default = "pom.xml"
required = False


class JvmArtifactsPackageMappingField(DictStringToStringSequenceField):
alias = "package_mapping"
help = help_text(
f"""
A mapping of jvm artifacts to a list of the packages they provide.

For example, `{{"com.google.guava:guava": ["com.google.common.**"]}}`.

Any unspecified jvm artifacts will use a default. See the
`{JvmArtifactPackagesField.alias}` field from the `{JvmArtifactTarget.alias}`
target for more information.
"""
)
value: FrozenDict[str, tuple[str, ...]]
default: ClassVar[Optional[FrozenDict[str, tuple[str, ...]]]] = FrozenDict()

@classmethod
def compute_value( # type: ignore[override]
cls, raw_value: dict[str, Iterable[str]], address: Address
) -> FrozenDict[tuple[str, str], tuple[str, ...]]:
value_or_default = super().compute_value(raw_value, address)
assert value_or_default is not None
return FrozenDict(
{
cls._parse_coord(coord): tuple(packages)
for coord, packages in value_or_default.items()
}
)

@classmethod
def _parse_coord(cls, coord: str) -> tuple[str, str]:
group, artifact = coord.split(":")
return group, artifact


class JvmArtifactsTargetGenerator(TargetGenerator):
alias = "jvm_artifacts"
core_fields = (
PomXmlSourceField,
JvmArtifactsPackageMappingField,
*COMMON_TARGET_FIELDS,
)
generated_target_cls = JvmArtifactTarget
copied_fields = COMMON_TARGET_FIELDS
moved_fields = (JvmArtifactResolveField,)
help = help_text(
"""
Generate a `jvm_artifact` target for each dependency in pom.xml file.
"""
)


class GenerateFromPomXmlRequest(GenerateTargetsRequest):
generate_from = JvmArtifactsTargetGenerator


@rule(
desc=("Generate `jvm_artifact` targets from pom.xml"),
level=LogLevel.DEBUG,
)
async def generate_from_pom_xml(
request: GenerateFromPomXmlRequest,
union_membership: UnionMembership,
) -> GeneratedTargets:
generator = request.generator
pom_xml = await Get(
SourceFiles,
SourceFilesRequest([generator[PomXmlSourceField]]),
)
files = await Get(DigestContents, Digest, pom_xml.snapshot.digest)
if not files:
raise FileNotFoundError(f"pom.xml not found: {generator[PomXmlSourceField].value}")

mapping = request.generator[JvmArtifactsPackageMappingField].value
coordinates = parse_pom_xml(files[0].content, pom_xml_path=pom_xml.snapshot.files[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just occurred to me that if files is empty (which it shouldn't be due to other checks, but I hate relying on non-local checks) this will crash with a non-useful error message. So I'd recommend checking len(files) and erroring with "Found no file at "

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure b68ec95

targets = (
JvmArtifactTarget(
unhydrated_values={
"group": coord.group,
"artifact": coord.artifact,
"version": coord.version,
"packages": mapping.get((coord.group, coord.artifact)),
**request.template,
},
address=request.template_address.create_generated(coord.artifact),
grihabor marked this conversation as resolved.
Show resolved Hide resolved
)
for coord in coordinates
)
return GeneratedTargets(request.generator, targets)


def parse_pom_xml(content: bytes, pom_xml_path: str) -> Iterator[Coordinate]:
root = ET.fromstring(content.decode("utf-8"))
match = re.match(r"^(\{.*\})project$", root.tag)
if not match:
raise ValueError(
f"Unexpected root tag `{root.tag}` in {pom_xml_path}, expected tag `project`"
)

namespace = match.group(1)
for dependency in root.iter(f"{namespace}dependency"):
yield Coordinate(
group=get_child_text(dependency, f"{namespace}groupId"),
artifact=get_child_text(dependency, f"{namespace}artifactId"),
version=get_child_text(dependency, f"{namespace}version"),
)


def get_child_text(parent: ET.Element, child: str) -> str:
tag = parent.find(child)
if tag is None:
raise ValueError(f"missing element: {child}")
text = tag.text
if text is None:
raise ValueError(f"empty element: {child}")
return text


# -----------------------------------------------------------------------------------------------
# JUnit test support field(s)
# -----------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -867,6 +1005,7 @@ async def jvm_source_run_request(request: JvmRunnableSourceFieldSet) -> RunReque
def rules():
return [
*collect_rules(),
UnionRule(GenerateTargetsRequest, GenerateFromPomXmlRequest),
UnionRule(FieldDefaultFactoryRequest, JvmResolveFieldDefaultFactoryRequest),
*JvmArtifactFieldSet.jvm_rules(),
]
Expand Down
Loading
Loading