From 6776b4bd0345a93817e4a4ba1eaff3fc1320509c Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Wed, 18 Oct 2023 15:07:53 -0700 Subject: [PATCH] ROCm 6.0 prep changes --- op_builder/cpu_adagrad.py | 6 ------ op_builder/cpu_adam.py | 6 ------ op_builder/random_ltd.py | 3 --- op_builder/transformer.py | 3 --- 4 files changed, 18 deletions(-) diff --git a/op_builder/cpu_adagrad.py b/op_builder/cpu_adagrad.py index 6d70c93faac2..9dae52010803 100644 --- a/op_builder/cpu_adagrad.py +++ b/op_builder/cpu_adagrad.py @@ -38,10 +38,4 @@ def include_paths(self): CUDA_INCLUDE = [] elif not self.is_rocm_pytorch(): CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [ - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), - ] return ['csrc/includes'] + CUDA_INCLUDE diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index e500a1eea907..0317add1f00c 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -39,10 +39,4 @@ def include_paths(self): CUDA_INCLUDE = [] elif not self.is_rocm_pytorch(): CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [ - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), - ] return ['csrc/includes'] + CUDA_INCLUDE diff --git a/op_builder/random_ltd.py b/op_builder/random_ltd.py index 3fdc777215da..54af7150fb36 100644 --- a/op_builder/random_ltd.py +++ b/op_builder/random_ltd.py @@ -31,7 +31,4 @@ def sources(self): def include_paths(self): includes = ['csrc/includes'] - if self.is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)] return includes diff --git a/op_builder/transformer.py b/op_builder/transformer.py index 893145d44d94..8db30fdc6791 100644 --- a/op_builder/transformer.py +++ b/op_builder/transformer.py @@ -33,7 +33,4 @@ def sources(self): def include_paths(self): includes = ['csrc/includes'] - if self.is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)] return includes