Skip to content

Commit

Permalink
setup.py: Adjustments for ROCm install
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Jul 13, 2024
1 parent c13575d commit b4c30b2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements_rocm.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-r _requirements_base.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.0
torch>=2.3.0,<2.4.0
torch>=2.3.1,<2.4.0
26 changes: 15 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@
"zlib-wapi": ("zlib-wapi", ("conda-forge", )),
"xorg-libxft": ("xorg-libxft", ("conda-forge", ))}

# Include Pillow + Numpy here to avoid installing twice for updated version
_GROUPS = [["pytorch*", "torch*", "cpuonly", "numpy", "pillow"]]
# Include Pillow, Numpy and Matplotlib here to avoid installing twice for updated version
# It goes before pytorch as ROCm torch comes from pip and we don't want the additional packages
# from pip
_GROUPS = [["numpy", "pillow", "matplotlib"],
["pytorch*", "torch*", "cpuonly"]]
"""list[list[str]]: Packages that should be installed collectively at the same time """

_DEV_TOOLS = ["flake8", "mypy", "pylint", "pytest", "pytest-mock",
Expand Down Expand Up @@ -1089,10 +1092,11 @@ def _install_grouped_packages(self) -> None:
""" Install packages that should be installed collectively as a group """
assert self._env.backend is not None
force_pip = _FORCE_PIP.get("all", []) + _FORCE_PIP.get(self._env.backend, [])
use_pip = False
packages = []
channels: set[str] = set()
for group in _GROUPS:
packages = []
channels: set[str] = set()
use_pip = False

for item in group:
for idx, pkg in reversed(list(enumerate(self._packages.to_install))):
if item != pkg.name and not (item.endswith("*") and
Expand All @@ -1105,13 +1109,13 @@ def _install_grouped_packages(self) -> None:
channels.update(next((v[1] for v in _CONDA_MAPPING.values()
if v[0] == i_pkg.name),
("defaults", )))
if not packages:
return
if not packages:
continue

if use_pip:
self._from_pip(packages, " ".join(self._packages.pip_arguments))
else:
self._from_conda(packages, tuple(channels), conda_only=True)
if use_pip:
self._from_pip(packages, " ".join(self._packages.pip_arguments))
else:
self._from_conda(packages, tuple(channels), conda_only=True)

def _install_conda_packages(self) -> None:
""" Install required conda packages """
Expand Down

0 comments on commit b4c30b2

Please sign in to comment.