diff --git a/bindings/torch/setup.py b/bindings/torch/setup.py index 9385914a..533805e6 100644 --- a/bindings/torch/setup.py +++ b/bindings/torch/setup.py @@ -80,26 +80,64 @@ def find_cl_path(): cpp_standard = 14 # Get CUDA version and make sure the targeted compute capability is compatible -if os.system("nvcc --version") == 0: - nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode() - cuda_version = re.search(r"release (\S+),", nvcc_out) - - if cuda_version: - cuda_version = parse_version(cuda_version.group(1)) - print(f"Detected CUDA version {cuda_version}") - if cuda_version >= parse_version("11.0"): - cpp_standard = 17 - - supported_compute_capabilities = [ - cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version) - ] - - if not supported_compute_capabilities: - supported_compute_capabilities = [max_supported_compute_capability(cuda_version)] - - if supported_compute_capabilities != compute_capabilities: - print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.") - compute_capabilities = supported_compute_capabilities +def _maybe_find_nvcc(): + # Try PATH first + maybe_nvcc = shutil.which("nvcc") + + if maybe_nvcc is not None: + return maybe_nvcc + + # Then try CUDA_HOME from torch (cpp_extension.CUDA_HOME is undocumented, which is why we only use + # it as a fallback) + try: + from torch.utils.cpp_extension import CUDA_HOME + except ImportError: + return None + + if not CUDA_HOME: + return None + + return os.path.join(CUDA_HOME, "bin", "nvcc") + +def _maybe_nvcc_version(): + maybe_nvcc = _maybe_find_nvcc() + + if maybe_nvcc is None: + return None + + nvcc_version_result = subprocess.run( + [maybe_nvcc, "--version"], + text=True, + check=False, + stdout=subprocess.PIPE, + ) + + if nvcc_version_result.returncode != 0: + return None + + cuda_version = re.search(r"release (\S+),", nvcc_version_result.stdout) + + if not cuda_version: + return None + + return parse_version(cuda_version.group(1)) + +cuda_version = _maybe_nvcc_version() +if cuda_version is not None: + print(f"Detected CUDA version {cuda_version}") + if cuda_version >= parse_version("11.0"): + cpp_standard = 17 + + supported_compute_capabilities = [ + cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version) + ] + + if not supported_compute_capabilities: + supported_compute_capabilities = [max_supported_compute_capability(cuda_version)] + + if supported_compute_capabilities != compute_capabilities: + print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.") + compute_capabilities = supported_compute_capabilities min_compute_capability = min(compute_capabilities)