From 23a11a39510e2aefb48236e3d2672a7dcbfc42a3 Mon Sep 17 00:00:00 2001 From: hipudding Date: Sat, 22 Jul 2023 20:52:27 +0800 Subject: [PATCH] Make Ascend NPU available (#3831) * Make Ascend NPU available NPU accelerator support is introduced in (#3595). This commit provides two enhancements: 1. Add a new accelerator_name 'npu' for choosing, it can be specified by environment variable or auto detected. 2. Optimize auto detect code in get_accelerator to avoid too many layers of exception throwing. * Use DS_ACCELERATOR_LIST for overriding accelerators When detecting override accelerators there's an error message to show all support accelerators, using an accelerator list instead of hard coding accelerator names in this message. And fix code format issue(yapf). * Add HCCL backend HCCL is the distribute backend of Ascend NPU, it already implemented in npu plugin for pytorch (https://gitee.com/ascend/pytorch). Add HCCL backend as a not implemented backend to avoid not supported warning. * Add NPUNotImplementedBuilder Ascend NPU does not implement any op yet, leave npu folder empty will throw NoneType[op_name] when not supported op is called. Add this NPUNotImplementedBuilder as the default builder. * Optimize builder search logic 1. cpu and other backend implement their ops in sub dirs under op_builder, cuda_accelerator should skip these sub dirs. 2. Each backend will have its own NotImplementedBuilder, add device prefix to this class to distinguish. * Change the unimplemented builder name to the same for each backend --- accelerator/cuda_accelerator.py | 9 +++-- accelerator/npu_accelerator.py | 42 ++++++++------------- accelerator/real_accelerator.py | 67 ++++++++++++++++++++++----------- deepspeed/comm/comm.py | 8 ++++ deepspeed/comm/constants.py | 1 + op_builder/npu/__init__.py | 9 +++++ op_builder/npu/builder.py | 34 +++++++++++++++++ op_builder/npu/no_impl.py | 24 ++++++++++++ 8 files changed, 143 insertions(+), 51 deletions(-) create mode 100644 op_builder/npu/__init__.py create mode 100644 op_builder/npu/builder.py create mode 100644 op_builder/npu/no_impl.py diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 6effdb5bf225..8097f13a6c7f 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -236,9 +236,12 @@ def _lazy_init_class_dict(self): # put all valid class name <--> class type mapping into class_dict op_builder_dir = self.op_builder_dir() op_builder_module = importlib.import_module(op_builder_dir) - for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]): - # avoid self references - if module_name != 'all_ops' and module_name != 'builder' and module_name != 'cpu': + op_builder_absolute_path = os.path.dirname(op_builder_module.__file__) + for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]): + # avoid self references, + # skip sub_directories which contains ops for other backend(cpu, npu, etc.). + if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir( + os.path.join(op_builder_absolute_path, module_name)): module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) for member_name in module.__dir__(): if member_name.endswith( diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index a92cb0f77085..f34e8bbd9e6b 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -3,13 +3,8 @@ # DeepSpeed Team +import torch from .abstract_accelerator import DeepSpeedAccelerator -# During setup stage torch may not be installed, pass on no torch will -# allow op builder related API to be executed. -try: - import torch.npu -except ImportError: - pass class NPU_Accelerator(DeepSpeedAccelerator): @@ -209,32 +204,25 @@ def op_builder_dir(self): except ImportError: return "deepspeed.ops.op_builder.npu" - # dict that holds class name <--> class type mapping i.e. - # 'AsyncIOBuilder': - # this dict will be filled at init stage - class_dict = None - - def _lazy_init_class_dict(self): - if self.class_dict != None: - return - else: - self.class_dict = {} - # create an instance of op builder and return, name specified by class_name def create_op_builder(self, class_name): - self._lazy_init_class_dict() - if class_name in self.class_dict: - return self.class_dict[class_name]() - else: - return None + builder_class = self.get_op_builder(class_name) + if builder_class != None: + return builder_class() + return None # return an op builder class, name specified by class_name def get_op_builder(self, class_name): - self._lazy_init_class_dict() - if class_name in self.class_dict: - return self.class_dict[class_name] - else: - return None + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 + from op_builder.npu import NotImplementedBuilder + except ImportError: + from deepspeed.ops.op_builder.npu import NotImplementedBuilder + + # return a NPUNotImplementedBuilder to avoid get NoneType[Name] in unit tests + return NotImplementedBuilder def build_extension(self): from torch.utils.cpp_extension import BuildExtension diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 13f21b479769..6e86edd20876 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -50,7 +50,7 @@ def get_accelerator(): accelerator_name = None ds_set_method = None # 1. Detect whether there is override of DeepSpeed accelerators from environment variable. - # DS_ACCELERATOR = 'cuda'|'xpu'|'cpu' + DS_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps'] if "DS_ACCELERATOR" in os.environ.keys(): accelerator_name = os.environ["DS_ACCELERATOR"] if accelerator_name == "xpu": @@ -65,7 +65,11 @@ def get_accelerator(): except ImportError as e: raise ValueError( f"CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.") - elif accelerator_name == "cuda": + elif accelerator_name == "npu": + try: + import torch_npu # noqa: F401 + except ImportError as e: + raise ValueError(f"NPU_Accelerator requires torch_npu, which is not installed on this system.") pass elif accelerator_name == "mps": try: @@ -75,41 +79,58 @@ def get_accelerator(): torch.mps.current_allocated_memory() except (RuntimeError, ImportError) as e: raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.") + elif accelerator_name == "cuda": + pass else: raise ValueError( - f'DS_ACCELERATOR must be one of "cuda", "cpu", or "xpu". Value "{accelerator_name}" is not supported') + f'DS_ACCELERATOR must be one of {DS_ACCELERATOR_LIST}. Value "{accelerator_name}" is not supported') ds_set_method = "override" # 2. If no override, detect which accelerator to use automatically if accelerator_name == None: + # We need a way to choose among different accelerator types. + # Currently we detect which accelerator extension is installed + # in the environment and use it if the installing answer is True. + # An alternative might be detect whether CUDA device is installed on + # the system but this comes with two pitfalls: + # 1. the system may not have torch pre-installed, so + # get_accelerator().is_available() may not work. + # 2. Some scenario like install on login node (without CUDA device) + # and run on compute node (with CUDA device) may cause mismatch + # between installation time and runtime. + try: from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401,F811 accelerator_name = "xpu" except ImportError as e: - # We need a way to choose between CUDA_Accelerator and CPU_Accelerator - # Currently we detect whether intel_extension_for_pytorch is installed - # in the environment and use CPU_Accelerator if the answer is True. - # An alternative might be detect whether CUDA device is installed on - # the system but this comes with two pitfalls: - # 1. the system may not have torch pre-installed, so - # get_accelerator().is_available() may not work. - # 2. Some scenario like install on login node (without CUDA device) - # and run on compute node (with CUDA device) may cause mismatch - # between installation time and runtime. + pass + if accelerator_name == None: try: import intel_extension_for_pytorch # noqa: F401,F811 accelerator_name = "cpu" except ImportError as e: - try: - import torch.mps - - # should use torch.mps.is_available() if it exists someday but this is used as proxy - torch.mps.current_allocated_memory() - accelerator_name = "mps" - except (RuntimeError, ImportError) as e: - accelerator_name = "cuda" + pass + if accelerator_name == None: + try: + import torch_npu # noqa: F401,F811 + + accelerator_name = "npu" + except ImportError as e: + pass + if accelerator_name == None: + try: + import torch.mps + + # should use torch.mps.is_available() if it exists someday but this is used as proxy + torch.mps.current_allocated_memory() + accelerator_name = "mps" + except (RuntimeError, ImportError) as e: + pass + if accelerator_name == None: + accelerator_name = "cuda" + ds_set_method = "auto detect" # 3. Set ds_accelerator accordingly @@ -124,6 +145,10 @@ def get_accelerator(): elif accelerator_name == "xpu": # XPU_Accelerator is already imported in detection stage ds_accelerator = XPU_Accelerator() + elif accelerator_name == "npu": + from .npu_accelerator import NPU_Accelerator + + ds_accelerator = NPU_Accelerator() elif accelerator_name == "mps": from .mps_accelerator import MPS_Accelerator diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index bec1de1e24c3..6dbb7f9dcc55 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -50,6 +50,7 @@ nccl_backend = None mpi_backend = None ccl_backend = None +hccl_backend = None # This should be set here so all rank/size information from the launcher can be propagated from deepspeed.comm.utils import * @@ -144,6 +145,7 @@ def init_deepspeed_backend(ds_backend, timeout, init_method): global nccl_backend global mpi_backend global ccl_backend + global hccl_backend rank = int(os.getenv('RANK', '-1')) size = int(os.getenv('WORLD_SIZE', '-1')) @@ -157,6 +159,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method): elif ds_backend == CCL_BACKEND: ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method) utils.logger.info(f"Initialize {ds_backend} backend") + elif ds_backend == HCCL_BACKEND: + utils.logger.warn("HCCL backend in DeepSpeed not yet implemented") else: utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend") @@ -196,6 +200,7 @@ def set_backend(): global nccl_backend global mpi_backend global ccl_backend + global hccl_backend backend_name = get_accelerator().communication_backend_name() @@ -208,6 +213,9 @@ def set_backend(): elif backend_name == CCL_BACKEND: if ccl_backend is not None and ccl_backend.is_initialized(): cdb = ccl_backend + elif backend_name == HCCL_BACKEND: + if hccl_backend is not None and hccl_backend.is_initialized(): + cdb = hccl_backend @timed_op diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index d2a79db95975..50d234c93fa0 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -8,6 +8,7 @@ MPI_BACKEND = 'mpi' GLOO_BACKEND = 'gloo' SCCL_BACKEND = 'sccl' +HCCL_BACKEND = 'hccl' DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" diff --git a/op_builder/npu/__init__.py b/op_builder/npu/__init__.py new file mode 100644 index 000000000000..3efcc859aaa8 --- /dev/null +++ b/op_builder/npu/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +# NPU related operators will be added in the future. + +from .no_impl import NotImplementedBuilder diff --git a/op_builder/npu/builder.py b/op_builder/npu/builder.py new file mode 100644 index 000000000000..3ac59ac4defa --- /dev/null +++ b/op_builder/npu/builder.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class NPUOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + return [] + + def libraries_args(self): + return [] diff --git a/op_builder/npu/no_impl.py b/op_builder/npu/no_impl.py new file mode 100644 index 000000000000..f17973fda401 --- /dev/null +++ b/op_builder/npu/no_impl.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + + +class NotImplementedBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on NPU backend.") + + def sources(self): + return []