Skip to content

Commit

Permalink
Improve the parsing of requirement files
Browse files Browse the repository at this point in the history
  • Loading branch information
realshouzy committed Jun 11, 2024
1 parent 712703e commit 2e6bef4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
29 changes: 23 additions & 6 deletions pip_manage/pip_purge.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _is_installed(package: str) -> bool:
return True


def _parse_requirements(
def _get_distribution_requirements(
requirements: list[str] | None,
*,
ignore_extra: bool,
Expand Down Expand Up @@ -146,7 +146,11 @@ def _get_required_by(package: str, *, ignore_extra: bool) -> frozenset[str]:
dist.name
for dist in importlib.metadata.distributions()
if dist.name != package
and (package in _parse_requirements(dist.requires, ignore_extra=ignore_extra))
and package
in _get_distribution_requirements(
dist.requires,
ignore_extra=ignore_extra,
)
)


Expand All @@ -161,19 +165,30 @@ def _get_dependencies_of_package(
ignore_extra: bool,
) -> _DependencyInfo:
assert _is_installed(package)
dependencies: frozenset[str] = _parse_requirements(
dependencies: frozenset[str] = _get_distribution_requirements(
importlib.metadata.distribution(package).requires,
ignore_extra=ignore_extra,
)
dependents: frozenset[str] = _get_required_by(package, ignore_extra=ignore_extra)
return _DependencyInfo(dependencies, dependents)


def _extract_package_from_requirements_file_line(requirement: str) -> str:
assert not requirement.lstrip().startswith("#")
for char in "#;": # do not change order
requirement = requirement.partition(char)[0].strip()
for char in "!<>=": # also do not change order
if char in requirement:
return requirement.partition(char)[0].strip()
return requirement.strip()


def _read_from_requirements(requirement_files: list[Path]) -> list[str]:
return [
requirement.strip()
_extract_package_from_requirements_file_line(line)
for requirement_file in requirement_files
for requirement in requirement_file.read_text(encoding="utf-8").splitlines()
for line in requirement_file.read_text(encoding="utf-8").splitlines()
if not line.lstrip().startswith("#")
]


Expand Down Expand Up @@ -209,7 +224,9 @@ def main( # pylint: disable=R0914, R0915 # noqa: PLR0915
", ".join(formatted_unrecognized_arg),
)
try:
requirements: list[str] = _read_from_requirements(args.requirements)
requirements: list[str] = _read_from_requirements(
args.requirements,
)
except OSError as err:
logger.error("Could not open requirements file: %s", err)
return 1
Expand Down
30 changes: 24 additions & 6 deletions tests/pip_purge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ def test_is_installed_with_mocked_package_found() -> None:
assert pip_purge._is_installed("test")


def test_parse_requirements_without_ignoring_extra() -> None:
def test_parse_distribution_requirements_without_ignoring_extra() -> None:
with mock.patch(
"importlib.metadata.distribution",
side_effect=_raise_package_not_found_error_when_package_c,
):
assert pip_purge._parse_requirements(
assert pip_purge._get_distribution_requirements(
[
"package_a",
"package_b <2.0,>=1.4",
Expand All @@ -192,12 +192,12 @@ def test_parse_requirements_without_ignoring_extra() -> None:
) == frozenset(("package_a", "package_b", "package_d"))


def test_parse_requirements_ignoring_extra() -> None:
def test_parse_distribution_requirements_ignoring_extra() -> None:
with mock.patch(
"importlib.metadata.distribution",
side_effect=_raise_package_not_found_error_when_package_c,
):
assert pip_purge._parse_requirements(
assert pip_purge._get_distribution_requirements(
[
"package_a",
"package_b <2.0,>=1.4",
Expand Down Expand Up @@ -398,15 +398,33 @@ def test_freeze_packages(tmp_path: Path) -> None:
)


@pytest.mark.parametrize(
("requirement", "package_name"),
[
("package_a", "package_a"),
("package_b <2.0,>=1.4", "package_b"),
("package_c <2.0,>=1.4;python_version<'3.11'", "package_c"),
("package_d!=3.0", "package_d"),
(" package_e#comment", "package_e"),
],
)
def test_parse_requirements_file_line(requirement: str, package_name: str) -> None:
assert (
pip_purge._extract_package_from_requirements_file_line(requirement)
== package_name
)


def test_read_requirements(tmp_path: Path) -> None:
tmp_file1: Path = tmp_path / "requirements1.txt"
tmp_file1.write_text("package_a\npackage_b\n")
tmp_file1.write_text("package_a<2.0,>=1.4\npackage_b;python_version<'3.11'\n")
tmp_file2: Path = tmp_path / "requirements2.txt"
tmp_file2.write_text("package_c\n")
tmp_file2.write_text("package_c\n#test\npackage_d!=3.0\n")
assert pip_purge._read_from_requirements([tmp_file1, tmp_file2]) == [
"package_a",
"package_b",
"package_c",
"package_d",
]


Expand Down

0 comments on commit 2e6bef4

Please sign in to comment.