Skip to content

Commit

Permalink
ROCm 6.0 prep changes
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams committed Oct 18, 2023
1 parent a735881 commit 6776b4b
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 18 deletions.
6 changes: 0 additions & 6 deletions op_builder/cpu_adagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,4 @@ def include_paths(self):
CUDA_INCLUDE = []
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = [
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
]
return ['csrc/includes'] + CUDA_INCLUDE
6 changes: 0 additions & 6 deletions op_builder/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,4 @@ def include_paths(self):
CUDA_INCLUDE = []
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = [
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
]
return ['csrc/includes'] + CUDA_INCLUDE
3 changes: 0 additions & 3 deletions op_builder/random_ltd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,4 @@ def sources(self):

def include_paths(self):
includes = ['csrc/includes']
if self.is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)]
return includes
3 changes: 0 additions & 3 deletions op_builder/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,4 @@ def sources(self):

def include_paths(self):
includes = ['csrc/includes']
if self.is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)]
return includes

0 comments on commit 6776b4b

Please sign in to comment.