Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
CurryRice233 authored Oct 14, 2023
2 parents fb77cd0 + 4fc181b commit 38bc926
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 18 deletions.
59 changes: 48 additions & 11 deletions csrc/cpu/comm/ccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,20 @@ int world_size = -1;

std::set<int> _comm_ids;
std::set<int> _colors;
ccl::vector_class<ccl::communicator> _ccl_comms;
std::vector<ccl::communicator> _ccl_comms;
ccl::shared_ptr_class<ccl::kvs> sub_kvs;
std::map<std::vector<int>, int> group_to_comm_id;

ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(std::vector<int> ranks)
{
if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) {
auto id = group_to_comm_id.find(ranks);
return _ccl_comms[id->second];
}
return _ccl_comms[0];
}

#define CCLCHECK(cmd) \
do { \
Expand Down Expand Up @@ -394,12 +404,29 @@ int next_unique_val(std::set<int> s)
}
}

py::object new_group(std::vector<int> ranks)
std::vector<uint8_t> get_sub_kvs_addr(bool first)
{
if (first) {
sub_kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = sub_kvs->get_address();
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
} else {
ccl::kvs::address_type main_addr;
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
}
}

void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector<int> ranks)
{
int comm_id = next_unique_val(_comm_ids);
int color = next_unique_val(_colors);
std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
<< std::endl;
ccl::kvs::address_type main_addr;
if (rank != 0) {
memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
sub_kvs = ccl::create_kvs(main_addr);
}
_ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs));
group_to_comm_id[ranks] = _ccl_comms.size() - 1;
}

ccl::datatype get_ccl_datatype(c10::ScalarType type)
Expand Down Expand Up @@ -452,7 +479,7 @@ ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
return ccl_op;
}

void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
void broadcast(torch::Tensor& data, int src, std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::broadcast(data.data_ptr(),
data.numel(),
Expand All @@ -463,7 +490,7 @@ void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
}

// TODO: implement torch's async_op behavior, document it.
void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
void all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::allreduce(data.data_ptr(),
data.data_ptr(),
Expand All @@ -477,7 +504,7 @@ void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async
void all_reduce_caching(torch::Tensor& data,
py::object op,
std::string match_id,
py::object group,
std::vector<int> group,
bool async_op)
{
ccl::allreduce_attr attr = ccl::default_allreduce_attr;
Expand Down Expand Up @@ -510,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
}
}

void inference_all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
{
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
Expand Down Expand Up @@ -583,11 +610,18 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
}
}

void barrier(py::object group, bool async_op)
void barrier(std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}

std::vector<std::string> get_available_coll()
{
std::vector<std::string> colls{
"broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
return colls;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
Expand All @@ -599,4 +633,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
m.def("barrier", &barrier, "barrier");
m.def("initialize_sub_comm", &initialize_sub_comm, "initialize_sub_comm");
m.def("get_sub_kvs_addr", &get_sub_kvs_addr, "get_sub_kvs_addr");
m.def("get_available_coll", &get_available_coll, "get_available_coll");
}
126 changes: 119 additions & 7 deletions deepspeed/comm/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def build_ccl_op():
return ccl_cpp_module


class CCLHandler():

def __init__(self, ccl_comm_op=None):
self.ccl_comm_op = ccl_comm_op

def wait(self):
# backend covered it
pass


class CCLBackend(TorchBackend):

def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
Expand All @@ -39,27 +49,129 @@ def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, i
size = self.get_world_size()
rank = self.get_rank()
main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
main_kvs = torch.tensor(main_kvs).to(torch.uint8)
main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
super(CCLBackend, self).broadcast(main_kvs, 0)
self.ccl_comm_op.initialize(size, rank, main_kvs)
self.initialized = True
self.groups = [tuple(range(self.get_world_size()))]
self.available_coll = self.ccl_comm_op.get_available_coll()

def is_initialized(self):
return self.initialized

def broadcast(self, tensor, src, group=None, async_op=False):
self.ccl_comm_op.broadcast(tensor, src, group, async_op)
def run_collective(self, name, **kwargs):
if name in self.available_coll:
kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
if 'dst' in kwargs:
kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
if 'src' in kwargs:
kwargs['src'] = kwargs['group'].index(kwargs['src'])
func = "self.ccl_comm_op." + name
eval(func)(*(kwargs.values()))
return CCLHandler(self.ccl_comm_op)
else:
func = "super(CCLBackend, self)." + name
return eval(func)(*(kwargs.values()))

def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
return self.run_collective(name="all_reduce_caching",
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
else:
self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
self.ccl_comm_op.inference_all_reduce(tensor, op, group, async_op)
return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)

def broadcast(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)

def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return self.run_collective(name="all_gather",
tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)

def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
return self.run_collective(name="reduce_scatter_tensor",
output_tensor=output_tensor,
input_tensor=input_tensor,
op=op,
group=group)

def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
return self.run_collective(name="all_gather_into_tensor",
output_tensor=output_tensor,
input_tensor=input_tensor,
group=group)

def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
return self.run_collective(name="all_to_all_single",
output=output,
input=input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group)

def send(self, tensor, dst, group=None, async_op=False):
return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)

def recv(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)

def gather(self, tensor, gather_list, dst, group=None, async_op=False):
return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)

def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)

def barrier(self, group=None, async_op=False):
self.ccl_comm_op.barrier(group, async_op)
return self.run_collective(name="barrier", group=group, async_op=async_op)

def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
return self.run_collective(name="monitored_barrier", group=group)

def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
return self.run_collective(name="reduce_scatter",
output=output,
input_list=input_list,
op=op,
group=group,
async_op=async_op)

def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)

def new_group(self, ranks):
return super(CCLBackend, self).new_group(ranks)

def _new_group(self, ranks, group):
size = len(ranks)
rank = self.get_rank()
sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
self.groups.append(tuple(ranks))

def get_all_ranks_from_group(self, group):
if group is None:
return list(range(self.get_world_size()))
rank = 0
results = []
try:
while True:
results.append(super(CCLBackend, self).get_global_rank(group, rank))
rank += 1
except RuntimeError:
pass
if tuple(results) not in self.groups:
self._new_group(results, group)
return results

0 comments on commit 38bc926

Please sign in to comment.