Skip to content

Commit

Permalink
Read from hardcoded numpy version
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 19, 2024
1 parent 4f1d61d commit 6323b0b
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions utils/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import subprocess
from pathlib import Path

from .python_utils import DEFAULT_PYTHON_VERSION

from typing import Optional

# defines the default CUDA version to compile against
DEFAULT_CUDA_VERSION = "12.4"
REPO_ROOT = Path(__file__).parent.parent

CUDA_VERSION_MAP = {
"12.4": {
Expand All @@ -23,14 +22,16 @@
TORCHBENCH_TORCH_NIGHTLY_PACKAGES = ["torch", "torchvision", "torchaudio"]

def _get_pin_numpy_version() -> str:
# the numpy version needs to be consistent with
# https://github.com/pytorch/builder/blob/main/wheel/build_wheel.sh#L146
RAW_GITHUB_BUILDER_SCRIPT = "https://raw.githubusercontent.com/pytorch/builder/main/wheel/build_wheel.sh"
numpy_version = "2.0.0rc1"
requirements_file = REPO_ROOT.joinpath("requirements.txt")
numpy_reg = "numpy==(.*)"
with open(requirements_file, "r") as fp:
numpy_requirement = list(filter(lambda x: "numpy==" in x, fp.readlines()))
assert numpy_requirement, f"Expected numpy version hardcoded in {str(requirements_file.resolve())}."
numpy_version = re.match(numpy_reg, numpy_requirement[0]).groups()[0]
print(f"Pinned NUMPY version: {numpy_version}")
return numpy_version

PIN_CMAKE_VERSION = _get_pin_numpy_version()
PIN_NUMPY_VERSION = _get_pin_numpy_version()

def _nvcc_output_match(nvcc_output, target_cuda_version):
regex = "release (.*),"
Expand Down

0 comments on commit 6323b0b

Please sign in to comment.