Merge branch 'master' into mrwyattii/pydantic-2-support
Showing 14 changed files with 525 additions and 17 deletions.
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ipex.h>
#include <torch/extension.h>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace xpu;

void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
{
    // get the sign bit of each float and pack eight of them into one byte
    int i = item_ct1;
    for (int j = 0; j < 8; ++j) {
        int k = i * 8 + j;
        int bit = k < input_size && (!sycl::signbit(input[k]));
        output[i] |= bit << (7 - j);
    }
}

void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
{
    // use the bit value to set the float: bit 0 -> float -1, bit 1 -> float 1
    int i = item_ct1;
    output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
}

sycl::queue get_current_queue(at::Device device)
{
    c10::impl::VirtualGuardImpl impl(device.type());
    c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
    sycl::queue queue = xpu::get_queue_from_stream(_stream);
    return queue;
}

/*
pack a float tensor into a uint8 tensor. Every eight float elements get packed into one uint8.
If a float x >= 0, it is packed as a '1' bit; otherwise it is packed as a '0' bit.
Arguments:
    tensor: the float tensor that gets packed.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    int packed_size = (input_size + 7) / 8;
    auto uint8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
    at::Tensor packed = torch::zeros({packed_size}, uint8_options);

    float* input = (float*)tensor.data_ptr();
    uint8_t* output = (uint8_t*)packed.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
            packbitskernel(input, output, input_size, item_ct1);
        });
    });

    return packed;
}

/*
unpack a uint8 tensor into a float tensor. Every uint8 element gets unpacked into eight floats.
A '1' bit is converted to float(1); a '0' bit is converted to float(-1).
Arguments:
    tensor: the uint8 tensor that gets unpacked.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
    at::Tensor unpacked = torch::empty({input_size * 8}, float_options);

    uint8_t* input = (uint8_t*)tensor.data_ptr();
    float* output = (float*)unpacked.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(input_size * 8),
                           [=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
    });

    return unpacked;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)");
    m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)");
}
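The extension above only needs to round-trip sign bits, so a quick way to exercise it from Python is to pack a float tensor and unpack it again. The sketch below is illustrative, not part of this commit; it assumes an XPU build of DeepSpeed where PackbitsBuilder can JIT-load this source, and the rank-0 choice and test tensor are mine.

# Minimal sign round-trip for the packbits/unpackbits ops above (assumes an
# available XPU device and a successful JIT build of this extension).
import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder

packer = PackbitsBuilder().load()
rank = 0  # illustrative; selects the "xpu:0" stream inside the kernel launchers
x = torch.randn(32, device=get_accelerator().current_device_name())

packed = packer.packbits(x.float(), x.numel(), rank)        # 4 bytes for 32 floats
unpacked = packer.unpackbits(packed, packed.numel(), rank)  # 32 values, each +1.0 or -1.0

# Every non-negative input should come back as +1.0, every negative one as -1.0.
expected = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
assert torch.equal(unpacked, expected)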
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder


class CompressedBackend(object):

    def __init__(self, mpu=None):
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.mpu = mpu
            self.world_group = self.mpu.get_data_parallel_group()
        self.size = dist.get_world_size(group=self.world_group)
        self.rank = dist.get_rank(group=self.world_group)
        self.packer = PackbitsBuilder().load()

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def pack(self, buffer, size):
        # pack the float tensor into a uint8 tensor
        packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
        return packed.reshape(size, -1)

    def unpack(self, buffer, size, dtype):
        # unpack the uint8 tensor back into a float tensor
        unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
        return unpacked.reshape(size, -1).to(dtype)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)

        # align the size of original_buffer and error
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)

        recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
                                   dtype=sign_list_packed_tmp[0].dtype,
                                   device=sign_list_packed_tmp.device)

        sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
            for _ in range(self.size)
        ]

        # communication phase 1
        # all-to-all for the sign bits
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # all-gather for the scales
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
        compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
            .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)

        compensated_server_m.add_(server_error)

        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)

        # recvbuf_sign_server
        recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
                                              dtype=recvbuf_sign.dtype,
                                              device=server_sign_packed.device)

        recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

        # recvbuf_scale_server
        recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
                                               dtype=worker_scale.dtype,
                                               device=server_sign_packed.device)

        recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

        # communication phase 2
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

        buffer_m.data.copy_(
            self.unpack(flattened_recvbuf_sign_server, self.size,
                        torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
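For context, compressed_allreduce follows the 1-bit compression pattern: each rank all-to-alls only the sign bits of its error-compensated buffer plus a single scale, averages the chunks it owns, and all-gathers the compressed result, while worker_error and server_error accumulate what the sign quantization dropped. The sketch below shows how a caller might size those persistent error buffers; it is a rough illustration inferred from the code above (the padding rule, helper name, and state dict are mine, not an API shipped by this commit).

# Hypothetical driver for CompressedBackend.compressed_allreduce. Assumes
# deepspeed.comm is already initialized and `backend` is a CompressedBackend.
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

def compressed_allreduce_once(backend, grad, state):
    world = dist.get_world_size()
    # Pad so the flattened gradient splits into whole bytes per rank:
    # pack() reshapes to (world, -1), so numel must divide by 8 * world.
    padded = ((grad.numel() + 8 * world - 1) // (8 * world)) * (8 * world)
    device = get_accelerator().current_device_name()
    if "worker_error" not in state:
        # Persistent error-feedback buffers, reused across iterations.
        state["worker_error"] = torch.zeros(padded, device=device)
        state["server_error"] = torch.zeros(padded // world, device=device)
    return backend.compressed_allreduce(grad, state["worker_error"],
                                        state["server_error"], dist.get_rank())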