diff --git a/requirements/requirements_rocm.txt b/requirements/requirements_rocm.txt index 768b424a5b..556759ded8 100644 --- a/requirements/requirements_rocm.txt +++ b/requirements/requirements_rocm.txt @@ -1,3 +1,3 @@ -r _requirements_base.txt --extra-index-url https://download.pytorch.org/whl/rocm6.0 -torch>=2.3.0,<2.4.0 +torch>=2.3.1,<2.4.0 diff --git a/setup.py b/setup.py index 43d95cb4fd..a701f6d369 100755 --- a/setup.py +++ b/setup.py @@ -62,8 +62,11 @@ "zlib-wapi": ("zlib-wapi", ("conda-forge", )), "xorg-libxft": ("xorg-libxft", ("conda-forge", ))} -# Include Pillow + Numpy here to avoid installing twice for updated version -_GROUPS = [["pytorch*", "torch*", "cpuonly", "numpy", "pillow"]] +# Include Pillow, Numpy and Matplotlib here to avoid installing twice for updated version +# It goes before pytorch as ROCm torch comes from pip and we don't want the additional packages +# from pip +_GROUPS = [["numpy", "pillow", "matplotlib"], + ["pytorch*", "torch*", "cpuonly"]] """list[list[str]]: Packages that should be installed collectively at the same time """ _DEV_TOOLS = ["flake8", "mypy", "pylint", "pytest", "pytest-mock", @@ -1089,10 +1092,11 @@ def _install_grouped_packages(self) -> None: """ Install packages that should be installed collectively as a group """ assert self._env.backend is not None force_pip = _FORCE_PIP.get("all", []) + _FORCE_PIP.get(self._env.backend, []) - use_pip = False - packages = [] - channels: set[str] = set() for group in _GROUPS: + packages = [] + channels: set[str] = set() + use_pip = False + for item in group: for idx, pkg in reversed(list(enumerate(self._packages.to_install))): if item != pkg.name and not (item.endswith("*") and @@ -1105,13 +1109,13 @@ def _install_grouped_packages(self) -> None: channels.update(next((v[1] for v in _CONDA_MAPPING.values() if v[0] == i_pkg.name), ("defaults", ))) - if not packages: - return + if not packages: + continue - if use_pip: - self._from_pip(packages, " ".join(self._packages.pip_arguments)) - else: - self._from_conda(packages, tuple(channels), conda_only=True) + if use_pip: + self._from_pip(packages, " ".join(self._packages.pip_arguments)) + else: + self._from_conda(packages, tuple(channels), conda_only=True) def _install_conda_packages(self) -> None: """ Install required conda packages """