From 3e70a8871537e32e15925e446e9adfebbc7fe679 Mon Sep 17 00:00:00 2001
From: CurryRice233
Date: Wed, 18 Oct 2023 11:57:51 +0800
Subject: [PATCH] Add NPU FusedAdam support (#4343)

* add npu support dtypes

* add npu fused_adam support

* add license

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji

---------

Co-authored-by: jializheng
Co-authored-by: Olatunji Ruwase
Co-authored-by: Hz, Ji
---
 accelerator/npu_accelerator.py | 44 +++++++++++++-------
 op_builder/npu/__init__.py     |  2 +-
 op_builder/npu/fused_adam.py   | 74 ++++++++++++++++++++++++++++++++++
 3 files changed, 105 insertions(+), 15 deletions(-)
 create mode 100644 op_builder/npu/fused_adam.py

diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py
index 9776c74bc422..7ebbd320bb15 100644
--- a/accelerator/npu_accelerator.py
+++ b/accelerator/npu_accelerator.py
@@ -2,16 +2,28 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
 
+import importlib
+import inspect
-import torch
 
 from .abstract_accelerator import DeepSpeedAccelerator
 
+# During the setup stage torch may not be installed yet; passing on a
+# missing torch still allows the op-builder-related APIs to be executed.
+try:
+    import torch.npu
+except ImportError:
+    pass
 
 
 class NPU_Accelerator(DeepSpeedAccelerator):
 
     def __init__(self):
+        super().__init__()
         self._name = 'npu'
         self._communication_backend_name = 'hccl'
+        # dict that holds the class name <--> class type mapping, e.g.
+        # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>;
+        # this dict is filled lazily at first use
+        self.class_dict = None
 
     def is_synchronized_device(self):
         return False
@@ -216,25 +228,29 @@ def op_builder_dir(self):
         except ImportError:
             return "deepspeed.ops.op_builder.npu"
 
+    def _lazy_init_class_dict(self):
+        if self.class_dict:
+            return
+
+        op_builder_module = importlib.import_module(self.op_builder_dir())
+
+        # collect the op builder classes exported by op_builder/npu/__init__.py
+        self.class_dict = {}
+        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
+            self.class_dict[class_name] = class_obj
+
     # create an instance of op builder and return, name specified by class_name
     def create_op_builder(self, class_name):
         builder_class = self.get_op_builder(class_name)
-        if builder_class != None:
-            return builder_class()
-        return None
+        return None if builder_class is None else builder_class()
 
     # return an op builder class, name specified by class_name
     def get_op_builder(self, class_name):
-        try:
-            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
-            # if successful this also means we're doing a local install and not JIT compile path
-            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.npu import NotImplementedBuilder
-        except ImportError:
-            from deepspeed.ops.op_builder.npu import NotImplementedBuilder
-
-        # return a NPUNotImplementedBuilder to avoid get NoneType[Name] in unit tests
-        return NotImplementedBuilder
+        self._lazy_init_class_dict()
+        if class_name in self.class_dict:
+            return self.class_dict[class_name]
+        else:
+            return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
 
     def build_extension(self):
         from torch.utils.cpp_extension import BuildExtension

diff --git a/op_builder/npu/__init__.py b/op_builder/npu/__init__.py
index 3efcc859aaa8..0d9e76b3903f 100644
--- a/op_builder/npu/__init__.py
+++ b/op_builder/npu/__init__.py
@@ -5,5 +5,5 @@
 '''Copyright The Microsoft DeepSpeed Team'''
 
 # NPU related operators will be added in the future.
-
+from .fused_adam import FusedAdamBuilder
 from .no_impl import NotImplementedBuilder

diff --git a/op_builder/npu/fused_adam.py b/op_builder/npu/fused_adam.py
new file mode 100644
index 000000000000..fc1bc83c7cc7
--- /dev/null
+++ b/op_builder/npu/fused_adam.py
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import NPUOpBuilder
+
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+
+class NPUFusedAdam:
+
+    @staticmethod
+    def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
+                          bias_correction, weight_decay, *args):
+        bias_correction1 = beta1**step
+        bias_correction2 = beta2**step
+
+        # iterate over group['params']
+        for i in range(len(tensor_lists[0])):
+            grad_flat = tensor_lists[0][i]
+            param_flat = tensor_lists[1][i]
+            m_flat = tensor_lists[2][i]
+            v_flat = tensor_lists[3][i]
+
+            if adam_w_mode:
+                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam_w(
+                    bias_correction1,
+                    bias_correction2,
+                    lr,
+                    weight_decay,
+                    beta1,
+                    beta2,
+                    epsilon,
+                    grad_flat,
+                    None,  # max_grad_norm
+                    False,  # amsgrad
+                    False,  # maximize
+                    out=(param_flat.data, m_flat, v_flat))
+            else:
+                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam(
+                    bias_correction1,
+                    bias_correction2,
+                    lr,
+                    beta1,
+                    beta2,
+                    epsilon,
+                    grad_flat,
+                    False,  # use_locking
+                    False,  # use_nesterov
+                    out=(param_flat.data, m_flat, v_flat))
+
+
+class FusedAdamBuilder(NPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
+    NAME = "fused_adam"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.adam.{self.NAME}_op'
+
+    def sources(self):
+        return []
+
+    def include_paths(self):
+        return []
+
+    def load(self, verbose=True):
+        return NPUFusedAdam
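
Reviewer note (not part of the patch): a minimal, untested sketch of how the new lazy op builder lookup is expected to behave once this change is applied. It assumes a DeepSpeed install with Ascend support where NPU_Accelerator is importable from deepspeed.accelerator.npu_accelerator; the name 'MissingBuilder' is a hypothetical placeholder for a builder that is not registered.

from deepspeed.accelerator.npu_accelerator import NPU_Accelerator

acc = NPU_Accelerator()

# A registered name resolves to the concrete class exported by
# op_builder/npu/__init__.py.
fused_adam_cls = acc.get_op_builder('FusedAdamBuilder')

# An unregistered name now falls back to NotImplementedBuilder instead of
# returning None, so unit tests never receive a NoneType.
fallback_cls = acc.get_op_builder('MissingBuilder')

# create_op_builder instantiates whatever get_op_builder resolves;
# FusedAdamBuilder.load() then hands back the NPUFusedAdam shim class.
builder = acc.create_op_builder('FusedAdamBuilder')
npu_fused_adam = builder.load()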
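
Reviewer note (not part of the patch): NPUFusedAdam.multi_tensor_adam mirrors the apex-style multi_tensor_adam signature, so DeepSpeed's FusedAdam optimizer can drive it through multi_tensor_applier with tensor_lists ordered as [grads, params, exp_avgs, exp_avg_sqs]. The sketch below calls the shim directly; it assumes an Ascend environment with torch_npu installed and that the module is reachable at deepspeed.ops.op_builder.npu.fused_adam (the path op_builder_dir() falls back to). All tensor shapes and hyperparameter values are illustrative only.

import torch
import torch_npu  # registers the 'npu' device type with torch

from deepspeed.ops.op_builder.npu.fused_adam import NPUFusedAdam

# One flattened tensor per parameter-group entry.
grads = [torch.full((8,), 0.1, device='npu')]
params = [torch.ones(8, device='npu')]
exp_avgs = [torch.zeros(8, device='npu')]
exp_avg_sqs = [torch.zeros(8, device='npu')]
noop_flag = torch.zeros(1, dtype=torch.int32, device='npu')

NPUFusedAdam.multi_tensor_adam(
    2048 * 32,  # chunk_size: kept for signature parity, unused by this shim
    noop_flag,  # noop_flag_buffer: likewise unused here
    [grads, params, exp_avgs, exp_avg_sqs],
    1e-3,  # lr
    0.9,  # beta1
    0.999,  # beta2
    1e-8,  # epsilon
    1,  # step
    1,  # adam_w_mode: nonzero selects the npu_apply_adam_w path
    True,  # bias_correction
    0.01)  # weight_decay

Because the NPU kernels write through out=(param, m, v), the parameters and optimizer states are updated in place and the call returns nothing.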