From 637961ba3258e987bec5fa51c220c08fd6c684c2 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 19 Jun 2024 12:55:17 -0400 Subject: [PATCH] Add build requirements --- utils/build_requirements.txt | 6 ++++++ utils/cuda_utils.py | 18 +++--------------- 2 files changed, 9 insertions(+), 15 deletions(-) create mode 100644 utils/build_requirements.txt diff --git a/utils/build_requirements.txt b/utils/build_requirements.txt new file mode 100644 index 000000000..36490ccac --- /dev/null +++ b/utils/build_requirements.txt @@ -0,0 +1,6 @@ +# We need to pin numpy version to the same as the torch testing environment +# which still supports python 3.8 +numpy==1.21.2; python_version < '3.11' +numpy==1.26.0; python_version >= '3.11' +psutil +pyyaml \ No newline at end of file diff --git a/utils/cuda_utils.py b/utils/cuda_utils.py index 7155b6337..92815679d 100644 --- a/utils/cuda_utils.py +++ b/utils/cuda_utils.py @@ -20,18 +20,7 @@ PIN_CMAKE_VERSION = "3.22.*" TORCHBENCH_TORCH_NIGHTLY_PACKAGES = ["torch", "torchvision", "torchaudio"] - -def _get_pin_numpy_version() -> str: - requirements_file = REPO_ROOT.joinpath("requirements.txt") - numpy_reg = "numpy==(.*)" - with open(requirements_file, "r") as fp: - numpy_requirement = list(filter(lambda x: "numpy==" in x, fp.readlines())) - assert numpy_requirement, f"Expected numpy version hardcoded in {str(requirements_file.resolve())}." - numpy_version = re.match(numpy_reg, numpy_requirement[0]).groups()[0] - print(f"Pinned NUMPY version: {numpy_version}") - return numpy_version - -PIN_NUMPY_VERSION = _get_pin_numpy_version() +BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt") def _nvcc_output_match(nvcc_output, target_cuda_version): regex = "release (.*)," @@ -163,9 +152,8 @@ def install_torch_build_deps(cuda_version: str): build_deps = ["ffmpeg"] cmd = ["conda", "install", "-y"] + build_deps subprocess.check_call(cmd) - # pip deps - pip_deps = [f"numpy=={PIN_NUMPY_VERSION}"] - cmd = ["pip", "install"] + pip_deps + # pip build deps + cmd = ["pip", "install", "-r"] + str(BUILD_REQUIREMENTS_FILE.resolve()) subprocess.check_call(cmd) # conda forge deps # ubuntu 22.04 comes with libstdcxx6 12.3.0