From d39cc23a2a3b639bb81837ad98577ee07ea67020 Mon Sep 17 00:00:00 2001 From: Tom Solberg Date: Wed, 19 Apr 2023 12:43:09 +0200 Subject: [PATCH] wip 250 (#7) Initial support for PDM 2.5.0. I believe we can clean this up a lot using multiple lockfiles. --- .buildkite/pipeline.yml | 5 +- .github/workflows/docs.yml | 2 +- .github/workflows/release.yml | 2 +- CHANGELOG.md | 4 + pdm-plugin-torch/pdm_plugin_torch/main.py | 262 +++++++++++++++------- tests/fixtures/cpu-only/pyproject.toml | 8 +- tests/test_lock.py | 12 +- 7 files changed, 205 insertions(+), 90 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f038955..3913c0e 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ plugin_base: &plugin_base service-account-name: monorepo-ci - image: gcr.io/embark-shared/ml/ci-runner@sha256:34407f69dae312c51ee1b30df13e2d39d2fda624017c76a42cd9ecbce127ddae + image: gcr.io/embark-shared/ml/ci-runner@sha256:59f978e2fa6f12a6e9714b59947028f89ad263f333a785531d58df8042af292a default-secret-name: buildkite-k8s-plugin always-pull: false use-agent-node-affinity: true @@ -64,8 +64,9 @@ steps: - label: ":pytest: Run tests @ {{matrix}}" matrix: - - "pdm" + - "pdm23" - "pdm24" + - "pdm25" command: bash .buildkite/run-pytest.sh {{matrix}} << : *small diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4960110..0b5577b 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -31,7 +31,7 @@ jobs: with: python-version: 3.8 architecture: x64 - version: 2.1.5 + version: 2.4.9 prerelease: false enable-pep582: true cache: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6b72505..95a0836 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ jobs: with: python-version: 3.8 architecture: x64 - version: 2.1.5 + version: 2.4.9 prerelease: true enable-pep582: true cache: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 80fcc12..3a5c97a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [Unreleased] + +* Adds support for PDM 2.5.0 + ## [23.0.0] - 2023-03-01 This is the initial release diff --git a/pdm-plugin-torch/pdm_plugin_torch/main.py b/pdm-plugin-torch/pdm_plugin_torch/main.py index 7603666..b6e58a5 100644 --- a/pdm-plugin-torch/pdm_plugin_torch/main.py +++ b/pdm-plugin-torch/pdm_plugin_torch/main.py @@ -7,7 +7,6 @@ import tomlkit from pdm import __version__, termui -from pdm._types import Source from pdm.cli.commands.base import BaseCommand from pdm.cli.utils import fetch_hashes, format_lockfile, format_resolution_impossible from pdm.core import Core @@ -20,7 +19,7 @@ from pdm.resolver import resolve from pdm.resolver.providers import BaseProvider from pdm.termui import Verbosity -from pdm.utils import atomic_open_for_write +from pdm.utils import atomic_open_for_write, expand_env_vars_in_auth from resolvelib.reporters import BaseReporter from resolvelib.resolvers import ResolutionImpossible, ResolutionTooDeep, Resolver @@ -28,21 +27,149 @@ is_pdm22 = PySpecSet("<2.3").contains(__version__.__version__) +is_pdm24 = PySpecSet("<2.5").contains(__version__.__version__) +if is_pdm24: + from pdm._types import Source as RepositoryConfig -def sources(project: Project, sources: list) -> list[Source]: - if all(source.get("name") != "pypi" for source in sources): - sources.insert(0, project.default_source) - expanded_sources: list[Source] = [ - Source( - url=s["url"], - verify_ssl=s.get("verify_ssl", True), - name=s.get("name"), - type=s.get("type", "index"), + def sources(project: Project, sources: list) -> list[RepositoryConfig]: + if all(source.get("name") != "pypi" for source in sources): + sources.insert(0, project.default_source) + + expanded_sources: list[RepositoryConfig] = [ + RepositoryConfig( + url=s["url"], + verify_ssl=s.get("verify_ssl", True), + name=s.get("name"), + type=s.get("type", "index"), + ) + for s in sources + ] + return expanded_sources + + def get_provider( + project: Project, + raw_sources: list, + strategy: str = "all", + for_install: bool = False, + lockfile: dict = None, + ) -> BaseProvider: + """Build a provider class for resolver. 
+ :param strategy: the resolve strategy + :param tracked_names: the names of packages that needs to update + :param for_install: if the provider is for install + :returns: The provider object + """ + + from pdm.resolver.providers import BaseProvider + + repository = get_repository( + project, raw_sources, for_install=for_install, lockfile=lockfile + ) + allow_prereleases = False + + return BaseProvider(repository, allow_prereleases, []) + +else: + from pdm._types import RepositoryConfig + + def sources(project: Project, sources: list) -> list[RepositoryConfig]: + result: dict[str, RepositoryConfig] = {} + for source in project.pyproject.settings.get("source", []): + result[source["name"]] = RepositoryConfig(**source) + + for source in sources: + result[source["name"]] = RepositoryConfig(**source) + + def merge_sources( + other_sources: Iterable[tuple[str, RepositoryConfig]] + ) -> None: + for name, source in other_sources: + source.name = name + if name in result: + result[name].passive_update(source) + else: + result[name] = source + + if not project.config.get("pypi.ignore_stored_index", False): + if "pypi" not in result: # put pypi source at the beginning + result = {"pypi": project.default_source, **result} + else: + result["pypi"].passive_update(project.default_source) + merge_sources(project.project_config.iter_sources()) + merge_sources(project.global_config.iter_sources()) + + for source in result.values(): + assert source.url, "Source URL must not be empty" + source.url = expand_env_vars_in_auth(source.url) + + return list(result.values()) + + def get_provider( + project: Project, + raw_sources: list, + strategy: str = "all", + for_install: bool = False, + lockfile: dict = None, + tracked_names: Iterable[str] | None = None, + allow_prereleases: bool = False, + ) -> BaseProvider: + """Build a provider class for resolver. 
+ :param strategy: the resolve strategy + :param tracked_names: the names of packages that needs to update + :param for_install: if the provider is for install + :returns: The provider object + """ + from pdm.models.requirements import strip_extras + from pdm.resolver.providers import ( + BaseProvider, + EagerUpdateProvider, + ReusePinProvider, + ) + from pdm.utils import normalize_name + + repository = get_repository( + project, raw_sources, for_install=for_install, lockfile=lockfile + ) + + overrides = { + normalize_name(k): v + for k, v in project.pyproject.resolution_overrides.items() + } + + locked_repository: LockedRepository | None = None + if strategy != "all" or for_install: + try: + locked_repository = LockedRepository( + lockfile, sources, project.environment + ) + except Exception: + if for_install: + raise + project.core.ui.echo( + "Unable to reuse the lock file as it is not compatible with PDM", + style="warning", + err=True, + ) + + if locked_repository is None: + return BaseProvider(repository, allow_prereleases, overrides) + + if for_install: + return BaseProvider(locked_repository, allow_prereleases, overrides) + + provider_class = ( + ReusePinProvider if strategy == "reuse" else EagerUpdateProvider + ) + tracked_names = [strip_extras(name)[0] for name in tracked_names or ()] + + return provider_class( + locked_repository.all_candidates, + tracked_names, + repository, + allow_prereleases, + overrides, ) - for s in sources - ] - return expanded_sources def get_repository( @@ -57,39 +184,12 @@ def get_repository( cls = project.core.repository_class fixed_sources = sources(project, raw_sources) - if for_install: - return LockedRepository(lockfile, fixed_sources, project.environment) - return cls( fixed_sources, project.environment, ) -def get_provider( - project: Project, - raw_sources: list, - strategy: str = "all", - for_install: bool = False, - lockfile: dict = None, -) -> BaseProvider: - """Build a provider class for resolver. 
- :param strategy: the resolve strategy - :param tracked_names: the names of packages that needs to update - :param for_install: if the provider is for install - :returns: The provider object - """ - - from pdm.resolver.providers import BaseProvider - - repository = get_repository( - project, raw_sources, for_install=for_install, lockfile=lockfile - ) - allow_prereleases = False - - return BaseProvider(repository, allow_prereleases, []) - - def do_lock( project: Project, raw_sources: list, @@ -108,13 +208,14 @@ def do_lock( with ui.open_spinner(title="Resolving dependencies") as spin: reporter = project.get_reporter(requirements, None, spin) resolver: Resolver = project.core.resolver_class(provider, reporter) - mapping, dependencies = resolve( resolver, requirements, project.environment.python_requires, resolve_max_rounds, ) + + spin.update("Fetching hashes for resolved packages...") fetch_hashes(provider.repository, mapping) except ResolutionTooDeep: @@ -163,7 +264,7 @@ def resolve_candidates_from_lockfile( if not req.marker or req.marker.evaluate(project.environment.marker_environment) ] with ui.logging("install-resolve"): - with ui.open_spinner("Resolving packages from lockfile..."): + with ui.open_spinner("Resolving packages from lockfile...") as spinner: reporter = BaseReporter() provider = get_provider( project, raw_sources, for_install=True, lockfile=lockfile @@ -175,6 +276,7 @@ def resolve_candidates_from_lockfile( project.environment.python_requires, resolve_max_rounds, ) + spinner.update("Fetching hashes for resolved packages...") fetch_hashes(provider.repository, mapping) return mapping @@ -205,7 +307,8 @@ def do_sync( only_keep=False, ) - handler.synchronize() + with project.core.ui.logging("install"): + handler.synchronize() def read_lockfile(project: Project, lock_name: str) -> None: @@ -285,37 +388,10 @@ def get_settings(project: Project): return project.pyproject.settings["plugins"]["torch"] -class TorchCommand(BaseCommand): - """Generate a lockfile for torch specifically.""" - - def add_arguments(self, parser): - subparsers = parser.add_subparsers(help="sub-command help", dest="command") - subparsers.required = True +class InstallCommand: + name = "install" - parser_install = subparsers.add_parser( - "install", help="install a torch variant" - ) - parser_install.add_argument( - "api", help="the api to use, e.g. 
cuda version or rocm" - ) - parser_install.set_defaults(command="install") - - parser_lock = subparsers.add_parser("lock", help="update lockfile") - parser_lock.add_argument( - "--check", - help="validate that the lockfile is up to date", - action="store_true", - ) - parser_lock.set_defaults(command="lock") - - def handle(self, project, options): - if options.command == "install": - self.handle_install(project, options) - - elif options.command == "lock": - self.handle_lock(project, options) - - def handle_install(self, project, options): + def handle(self, project: Project, options: dict): plugin_config = Configuration.from_toml(get_settings(project)) resolves = plugin_config.variants @@ -340,13 +416,18 @@ def handle_install(self, project, options): "name": "torch", "url": source, "type": "index", + "verify_ssl": True, } ], requirements=reqs, lockfile=spec_for_version, ) - def handle_lock(self, project, options): + +class LockCommand: + name = "lock" + + def handle(self, project: Project, options: dict): plugin_config = Configuration.from_toml(get_settings(project)) if options.check: @@ -380,6 +461,7 @@ def handle_lock(self, project, options): "name": "torch", "url": url, "type": "index", + "verify_ssl": True, } ], requirements=reqs, @@ -388,6 +470,34 @@ def handle_lock(self, project, options): write_lockfile(project, plugin_config.lockfile, results) +class TorchCommand(BaseCommand): + """Generate a lockfile for torch specifically.""" + + name = "torch" + + def add_arguments(self, parser): + subparsers = parser.add_subparsers(help="sub-command help", dest="command") + subparsers.required = True + + parser_install = subparsers.add_parser( + "install", help="install a torch variant" + ) + parser_install.add_argument( + "api", help="the api to use, e.g. 
cuda version or rocm" + ) + parser_install.set_defaults(command=InstallCommand()) + + parser_lock = subparsers.add_parser("lock", help="update lockfile") + parser_lock.add_argument( + "--check", + help="validate that the lockfile is up to date", + action="store_true", + ) + parser_lock.set_defaults(command=LockCommand()) + + def handle(self, project, options): + options.command.handle(project, options) + + def torch_plugin(core: Core): core.register_command(TorchCommand, "torch") - core.add_config("hello.name", ConfigItem("The person's name", "John")) diff --git a/tests/fixtures/cpu-only/pyproject.toml b/tests/fixtures/cpu-only/pyproject.toml index c40e2a4..437f48b 100644 --- a/tests/fixtures/cpu-only/pyproject.toml +++ b/tests/fixtures/cpu-only/pyproject.toml @@ -3,7 +3,7 @@ name = "test-cpu-only" authors = [ {name = "Tom Solberg", email = "me@sbg.dev"}, ] -requires-python = "~=3.8.0" +requires-python = ">=3.8" license = {text = "MIT"} dependencies = [] description = "" @@ -15,16 +15,16 @@ build-backend = "pdm.pep517.api" [tool.pdm.plugins.torch] dependencies = [ - "torch==1.10.2" + "torch==1.11.0" ] lockfile = "torch.lock" enable-cpu = true enable-rocm = false -rocm-versions = ["4.2"] +rocm-versions = ["4.5.2"] enable-cuda = false -cuda-versions = ["cu111", "cu113"] +cuda-versions = ["cu115", "cu117"] [tool.pdm.scripts] post_install = "pdm plugin add ../../" diff --git a/tests/test_lock.py b/tests/test_lock.py index eca990b..ae32d9a 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -57,13 +57,13 @@ def test_lock_check_fails(tmpdir, pdm): tmpdir_project("cpu-only", tmpdir) with pytest.raises(subprocess.CalledProcessError): - pdm(["torch", "lock", "--check"], tmpdir) + pdm(["torch", "-v", "lock", "--check"], tmpdir) @staticmethod def test_lock_plugin_check_succeeds(tmpdir, pdm): tmpdir_project("cpu-only", tmpdir) - pdm(["torch", "lock"], tmpdir) - pdm(["torch", "lock", "--check"], tmpdir) + pdm(["torch", "-v", "lock"], tmpdir) + pdm(["torch", "-v", "lock", "--check"], tmpdir) @staticmethod def test_install_fails(tmpdir, pdm): @@ -71,10 +71,10 @@ def test_install_fails(tmpdir, pdm): tmpdir_project("cpu-only", tmpdir) with pytest.raises(subprocess.CalledProcessError): - pdm(["torch", "install", "cpu"], tmpdir) + pdm(["torch", "-v", "install", "cpu"], tmpdir) @staticmethod def test_install_succeeds(tmpdir, pdm): tmpdir_project("cpu-only", tmpdir) - pdm(["torch", "lock"], tmpdir) - pdm(["torch", "install", "cpu"], tmpdir) + pdm(["torch", "-v", "lock"], tmpdir) + pdm(["torch", "-v", "install", "cpu"], tmpdir)
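
Note for reviewers: the two ideas that carry this patch are (a) gating on the installed PDM version, so the plugin keeps working against both the pre-2.5 `Source` type and the 2.5+ `RepositoryConfig` type, and (b) replacing the string-based dispatch in `TorchCommand.handle` with small command objects attached via `set_defaults(command=...)`. The snippet below is a condensed, illustrative restatement of those two patterns, not code lifted from the patch: it assumes PDM is importable in the current environment, and `LockCommand.handle`'s body plus the `project` argument are placeholders.

```python
# Illustrative sketch of the two patterns introduced in pdm_plugin_torch/main.py.
import argparse

from pdm import __version__
from pdm.models.specifiers import PySpecSet

# Pattern (a): pick the right PDM API surface at import time.
# PDM < 2.5 exposes the index/source type as `pdm._types.Source`;
# 2.5+ renames it to `RepositoryConfig`, so the old name is aliased to the new one.
is_pdm24 = PySpecSet("<2.5").contains(__version__.__version__)
if is_pdm24:
    from pdm._types import Source as RepositoryConfig
else:
    from pdm._types import RepositoryConfig

print("using source type:", RepositoryConfig.__name__)


# Pattern (b): each sub-command is a small object with a `handle()` method,
# attached through `set_defaults(command=...)`, so the top-level handler
# becomes a one-liner instead of branching on a string.
class LockCommand:
    name = "lock"

    def handle(self, project, options) -> None:
        # Placeholder body; the real command resolves and writes torch.lock.
        print(f"would lock torch variants for {project} (check={options.check})")


parser = argparse.ArgumentParser(prog="pdm torch")
subparsers = parser.add_subparsers(help="sub-command help", dest="command")
subparsers.required = True

parser_lock = subparsers.add_parser("lock", help="update lockfile")
parser_lock.add_argument("--check", action="store_true")
# On recent Python 3 the subparser's default replaces the sub-command name
# stored via `dest="command"`, leaving a ready-to-call object on the namespace.
parser_lock.set_defaults(command=LockCommand())

options = parser.parse_args(["lock", "--check"])
options.command.handle(project="cpu-only fixture", options=options)
```

With this shape, adding another sub-command later only needs a new class with `name` and `handle`, plus one `add_parser`/`set_defaults` pair in `TorchCommand.add_arguments`.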