diff --git a/op_builder/gds.py b/op_builder/gds.py index e024674e01d8..01c2d5a245d1 100644 --- a/op_builder/gds.py +++ b/op_builder/gds.py @@ -36,7 +36,13 @@ def extra_ldflags(self): return super().extra_ldflags() + ['-lcufile'] def is_compatible(self, verbose=False): - import torch.utils.cpp_extension + try: + import torch.utils.cpp_extension + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile GDS") + return False + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") gds_compatible = self.has_function(funcname="cuFileDriverOpen",