Add NPU FusedAdam support (#4343)
* add npu support dtypes

* add npu fused_adam support

* add license

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update op_builder/npu/fused_adam.py

Co-authored-by: Hz, Ji <[email protected]>

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji <[email protected]>

* Update accelerator/npu_accelerator.py

Co-authored-by: Hz, Ji <[email protected]>

---------

Co-authored-by: jializheng <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Hz, Ji <[email protected]>
4 people authored Oct 18, 2023
1 parent 2cbfb89 commit 3e70a88
Showing 3 changed files with 105 additions and 15 deletions.
44 changes: 30 additions & 14 deletions accelerator/npu_accelerator.py
@@ -2,16 +2,28 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import importlib
import inspect

import torch
from .abstract_accelerator import DeepSpeedAccelerator
# During the setup stage torch may not be installed yet; passing on a
# missing torch keeps the op-builder-related APIs usable.
try:
    import torch.npu
except ImportError:
    pass


class NPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        super().__init__()
        self._name = 'npu'
        self._communication_backend_name = 'hccl'
        # dict that holds the class name <--> class type mapping, e.g.
        # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>;
        # filled lazily on first lookup
        self.class_dict = None

    def is_synchronized_device(self):
        return False
@@ -216,25 +228,29 @@ def op_builder_dir(self):
         except ImportError:
             return "deepspeed.ops.op_builder.npu"
 
+    def _lazy_init_class_dict(self):
+        if self.class_dict:
+            return
+
+        op_builder_module = importlib.import_module(self.op_builder_dir())
+
+        # get op builder class from op_builder/npu/__init__.py
+        self.class_dict = {}
+        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
+            self.class_dict[class_name] = class_obj
+
     # create an instance of op builder and return, name specified by class_name
     def create_op_builder(self, class_name):
         builder_class = self.get_op_builder(class_name)
-        if builder_class != None:
-            return builder_class()
-        return None
+        return None if builder_class is None else builder_class()
 
     # return an op builder class, name specified by class_name
     def get_op_builder(self, class_name):
-        try:
-            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
-            # if successful this also means we're doing a local install and not JIT compile path
-            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.npu import NotImplementedBuilder
-        except ImportError:
-            from deepspeed.ops.op_builder.npu import NotImplementedBuilder
-
-        # return a NPUNotImplementedBuilder to avoid get NoneType[Name] in unit tests
-        return NotImplementedBuilder
+        self._lazy_init_class_dict()
+        if class_name in self.class_dict:
+            return self.class_dict[class_name]
+        else:
+            return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
 
     def build_extension(self):
         from torch.utils.cpp_extension import BuildExtension
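Editor's note: for orientation, a minimal usage sketch (not part of the diff) of the new lookup path. It assumes DeepSpeed selects the NPU accelerator, i.e. an Ascend environment with torch_npu installed; 'SomeMissingBuilder' is a deliberately unknown, hypothetical name.

# Hedged sketch: exercises the class_dict-based lookup added above.
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()  # NPU_Accelerator on an Ascend NPU machine

# A name exported from op_builder/npu/__init__.py resolves directly ...
builder_cls = acc.get_op_builder('FusedAdamBuilder')

# ... while an unknown name now falls back to NotImplementedBuilder
# instead of returning None, so unit tests don't hit NoneType errors.
fallback_cls = acc.get_op_builder('SomeMissingBuilder')  # hypothetical name

# create_op_builder() instantiates the class; load() returns the op impl.
adam_impl = acc.create_op_builder('FusedAdamBuilder').load()  # -> NPUFusedAdam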
2 changes: 1 addition & 1 deletion op_builder/npu/__init__.py
@@ -5,5 +5,5 @@
 '''Copyright The Microsoft DeepSpeed Team'''
 
 # NPU related operators will be added in the future.
-
+from .fused_adam import FusedAdamBuilder
 from .no_impl import NotImplementedBuilder
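Editor's note: because _lazy_init_class_dict() discovers builders with inspect.getmembers over this package, any class exported from this __init__.py becomes resolvable by name. An illustrative sketch of that discovery step, assuming a DeepSpeed install that includes this commit:

# Hedged sketch: reproduces the discovery performed by _lazy_init_class_dict().
import importlib
import inspect

mod = importlib.import_module("deepspeed.ops.op_builder.npu")
class_dict = {name: cls for name, cls in inspect.getmembers(mod, inspect.isclass)}
print(sorted(class_dict))  # expected to include 'FusedAdamBuilder' and 'NotImplementedBuilder'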
74 changes: 74 additions & 0 deletions op_builder/npu/fused_adam.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import NPUOpBuilder

try:
    import torch_npu
except ImportError:
    pass


class NPUFusedAdam:

    @staticmethod
    def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
                          bias_correction, weight_decay, *args):
        bias_correction1 = beta1**step
        bias_correction2 = beta2**step

        # iterate over group['params']
        for i in range(len(tensor_lists[0])):
            grad_flat = tensor_lists[0][i]
            param_flat = tensor_lists[1][i]
            m_flat = tensor_lists[2][i]
            v_flat = tensor_lists[3][i]

            if adam_w_mode:
                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam_w(
                    bias_correction1,
                    bias_correction2,
                    lr,
                    weight_decay,
                    beta1,
                    beta2,
                    epsilon,
                    grad_flat,
                    None,  # max_grad_norm
                    False,  # amsgrad
                    False,  # maximize
                    out=(param_flat.data, m_flat, v_flat))
            else:
                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam(
                    bias_correction1,
                    bias_correction2,
                    lr,
                    beta1,
                    beta2,
                    epsilon,
                    grad_flat,
                    False,  # use_locking
                    False,  # use_nesterov
                    out=(param_flat.data, m_flat, v_flat))


class FusedAdamBuilder(NPUOpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return []

    def include_paths(self):
        return []

    def load(self, verbose=True):
        return NPUFusedAdam
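
Editor's note: to show how the shim is consumed, here is a hedged sketch mimicking the call DeepSpeed's FusedAdam optimizer issues into multi_tensor_adam via multi_tensor_applier. Shapes and hyperparameters are illustrative, and it requires an Ascend device with torch_npu installed.

# Hedged sketch, not part of the commit.
import torch
import torch_npu  # noqa: F401  # registers the 'npu' device in torch

from deepspeed.ops.op_builder.npu import FusedAdamBuilder

multi_tensor_adam = FusedAdamBuilder().load().multi_tensor_adam

params = [torch.randn(16, device='npu')]
grads = [torch.randn(16, device='npu')]
exp_avg = [torch.zeros(16, device='npu')]     # first moment (m)
exp_avg_sq = [torch.zeros(16, device='npu')]  # second moment (v)

noop_flag = torch.zeros(1, dtype=torch.int32, device='npu')
multi_tensor_adam(
    2048 * 32,                             # chunk_size (unused by this shim)
    noop_flag,                             # noop flag buffer (unused by this shim)
    [grads, params, exp_avg, exp_avg_sq],  # tensor_lists, in the order the shim indexes
    1e-3,                                  # lr
    0.9, 0.999, 1e-8,                      # beta1, beta2, epsilon
    1,                                     # step
    True,                                  # adam_w_mode -> npu_apply_adam_w
    True,                                  # bias_correction (unused by this shim)
    0.01)                                  # weight_decay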
