Skip to content

Commit

Permalink
[NPU] Add HcclBackend for 1-bit adam, 1-bit lamb, 0/1 adam (microsoft…
Browse files Browse the repository at this point in the history
…#4733)

To support NPU devices fulfilling feature requirements like 1-bit Adam,
1-bit Lamb, 0/1 Adam, I add HcclBackend and its corresponding import
logics.
See what we have already done in microsoft#4567 .

---------

Co-authored-by: ryan <[email protected]>
Co-authored-by: Conglong Li <[email protected]>
  • Loading branch information
3 people authored and amaurya committed Feb 17, 2024
1 parent e1ea991 commit 385448e
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 18 deletions.
124 changes: 124 additions & 0 deletions deepspeed/runtime/comm/hccl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import torch_npu
import deepspeed.comm as dist


class HcclBackend(object):

def __init__(self, mpu=None):
if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
else:
self.mpu = mpu
self.world_group = self.mpu.get_data_parallel_group()
self.size = dist.get_world_size(group=self.world_group)
self.rank = dist.get_rank(group=self.world_group)

def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
else:
recvbuf[rank] = sendbuf
else:
req.append(dist.isend(sendbuf, group=group, dst=root))
return req

def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
if rank == root:
for idx in range(size):
if idx != rank:
dist.recv(recvbuf[idx], src=idx, group=group)
else:
recvbuf[rank] = sendbuf
else:
dist.send(sendbuf, group=group, dst=root)

def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
original_shape = buffer_m.size()
if len(original_shape) > 1:
buffer_m = torch.flatten(buffer_m)

# align size of original_buffer and error
original_size = buffer_m.numel()
worker_error_size = worker_error.numel()
if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])

buffer_m.add_(worker_error)
worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

sign_list_packed_tmp = torch_npu.npu_sign_bits_pack(buffer_m, self.size).type(torch.int8)

recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
dtype=sign_list_packed_tmp[0].dtype,
device=sign_list_packed_tmp.device)

sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

recvbuf_scale = [
torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(local_rank)) for _ in range(self.size)
]

# communication phase 1
# all to all for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
# all gather for scale
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
compensated_server_m = torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign, self.size, torch.float32) \
.mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

server_sign_packed = torch_npu.npu_sign_bits_pack(compensated_server_m, 1).type(torch.int8)

# recvbuf_sign_server
recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
dtype=recvbuf_sign.dtype,
device=server_sign_packed.device)

recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

# recvbuf_scale_server
recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
dtype=worker_scale.dtype,
device=server_sign_packed.device)

recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

# communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

recvbuf_sign_server = torch.stack(recvbuf_sign_server)

flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

buffer_m.data.copy_(
torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign_server, self.size,
torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1:
buffer_m = buffer_m.reshape(original_shape)

return buffer_m
11 changes: 5 additions & 6 deletions deepspeed/runtime/fp16/onebit/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def __init__(self,

super(OnebitAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
assert (dist.is_initialized())

self.comm_time = 0.0
self.step_time = 0.0
self.ave_step = 1
Expand All @@ -86,22 +84,23 @@ def __init__(self,

self.comm_backend_name = comm_backend_name

assert dist.is_initialized(), "Please initialize the torch distributed backend."
# Empty initializer. Set handle based on the comm backend as follows.
self.comm_backend_handle = None

if self.comm_backend_name == 'nccl':
assert (
required_torch_version(min_version=1.8)
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
self.comm_backend_handle = MpiBackend(cuda_aware)

elif self.comm_backend_name == 'hccl':
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
self.size = self.comm_backend_handle.size

self.divider = int(self.size * 8 / np.gcd(self.size, 8))
Expand Down
12 changes: 6 additions & 6 deletions deepspeed/runtime/fp16/onebit/lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def __init__(self,

super(OnebitLamb, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
assert (dist.is_initialized())

self.deepspeed = deepspeed
self.lamb_freeze_key = False
self.initialize = False
Expand All @@ -108,21 +106,23 @@ def __init__(self,

self.comm_backend_name = comm_backend_name

assert dist.is_initialized(), "Please initialize the torch distributed backend."
# Empty initializer. Set handle based on the comm backend as follows.
self.comm_backend_handle = None

if self.comm_backend_name == 'nccl':
assert (
required_torch_version(min_version=1.8)
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
self.comm_backend_handle = MpiBackend(cuda_aware)
elif self.comm_backend_name == 'hccl':
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)

self.size = self.comm_backend_handle.size

Expand Down Expand Up @@ -161,7 +161,7 @@ def step(self, closure=None, grads=None):
else:
grads_group = grads

#remove the previous stats
# remove the previous stats
del self.lamb_coeffs[:]

if self.lamb_freeze_key:
Expand Down
11 changes: 5 additions & 6 deletions deepspeed/runtime/fp16/onebit/zoadam.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def __init__(self,

super(ZeroOneAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
assert (dist.is_initialized())

self.deepspeed = deepspeed
self.initialize = False
self.cuda_aware = cuda_aware
Expand All @@ -99,22 +97,23 @@ def __init__(self,

self.comm_backend_name = comm_backend_name

assert dist.is_initialized(), "Please initialize the torch distributed backend."
# Empty initializer. Set handle based on the comm backend as follows.
self.comm_backend_handle = None

if self.comm_backend_name == 'nccl':
assert (
required_torch_version(min_version=1.8)
), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
self.comm_backend_handle = MpiBackend(cuda_aware)

elif self.comm_backend_name == 'hccl':
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
self.size = self.comm_backend_handle.size

self.divider = int(self.size * 8 / np.gcd(self.size, 8))
Expand Down

0 comments on commit 385448e

Please sign in to comment.