Fix CPU Adam JIT compilation (#5780)
This PR fixes CPU Adam JIT compilation by including the `CUDA_LIB64`
path in the `extra_ldflags` list before calling `load()`.
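
For context, here is a minimal sketch of the mechanism being fixed, assuming a CUDA build of PyTorch; the module name and source path below are illustrative placeholders, not DeepSpeed's actual build invocation. The flags returned by the op builder are forwarded to `torch.utils.cpp_extension.load()`, so appending `-L<CUDA_LIB64>` lets the linker resolve `-lcurand` during JIT compilation.

    # Sketch only: shows how an -L<CUDA_LIB64> entry in extra_ldflags reaches the JIT build.
    import os
    from torch.utils import cpp_extension

    cuda_lib64 = os.path.join(cpp_extension.CUDA_HOME, "lib64")
    if not os.path.exists(cuda_lib64):
        # Some CUDA installs place libraries under "lib" instead of "lib64".
        cuda_lib64 = os.path.join(cpp_extension.CUDA_HOME, "lib")

    example_op = cpp_extension.load(
        name="example_op",                              # hypothetical module name
        sources=["example_op.cpp"],                     # hypothetical source file
        extra_ldflags=["-lcurand", f"-L{cuda_lib64}"],  # -L path makes -lcurand resolvable
        verbose=True,
    )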

---------

Co-authored-by: Logan Adams <[email protected]>
lekurile and loadams authored Jul 31, 2024
1 parent 550f9c7 commit 681be6f
Showing 1 changed file with 15 additions and 8 deletions.
op_builder/builder.py (23 changes: 15 additions & 8 deletions)
@@ -800,25 +800,32 @@ def libraries_args(self):

 class TorchCPUOpBuilder(CUDAOpBuilder):
 
+    def get_cuda_lib64_path(self):
+        import torch
+        if not self.is_rocm_pytorch():
+            CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
+            if not os.path.exists(CUDA_LIB64):
+                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
+        else:
+            CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
+        return CUDA_LIB64
+
     def extra_ldflags(self):
         if self.build_for_cpu:
             return ['-fopenmp']
 
         if not self.is_rocm_pytorch():
-            return ['-lcurand']
+            ld_flags = ['-lcurand']
+            if not self.build_for_cpu:
+                ld_flags.append(f'-L{self.get_cuda_lib64_path()}')
+            return ld_flags
 
         return []
 
     def cxx_args(self):
         import torch
         args = []
         if not self.build_for_cpu:
-            if not self.is_rocm_pytorch():
-                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
-                if not os.path.exists(CUDA_LIB64):
-                    CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
-            else:
-                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
+            CUDA_LIB64 = self.get_cuda_lib64_path()
 
         args += super().cxx_args()
         args += [
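One way to exercise the fixed path is to JIT-build the CPU Adam op directly; the sketch below assumes a CUDA-enabled environment where the extension has not been pre-compiled.

    # Sketch: triggers JIT compilation of the CPU Adam op, which now links with -L<CUDA_LIB64>.
    from deepspeed.ops.op_builder import CPUAdamBuilder

    cpu_adam_module = CPUAdamBuilder().load()  # compiles on first call, then returns the module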
