From b22706a7211366abf2df98a0d118ea1d3a837e21 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 18 Apr 2024 02:52:36 +0800 Subject: [PATCH] [CPU] Support SHM based inference_all_reduce in TorchBackend (#5391) This PR adds SHM based `inference_all_reduce` kernel to `TorchBackend` communication backend. When inference on CPU server, this path replaces default `torch.distributed.all_reduce` which eventurally use gloo backend. This PR will improve inference performance with AutoTP when only stock PyTorch is installed without Intel Extension for PyTorch. Compared with gloo backend. SHM based inference_all_reduce kernel is a more directed path and perform much better on single node. | message size | gloo all_reduce(ms) | SHM all_reduce(ms) | | --- | --- | --- | | 32MB | 30.7 | 0.65 | | 64KB | 0.23 | 0.028 | In text generation of bloom-3b with AutoTP, average token latency improved 1.45x with this PR on 2S Xeon node. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- accelerator/cpu_accelerator.py | 6 +- csrc/cpu/comm/ccl.cpp | 2 +- csrc/cpu/comm/shm_interface.cpp | 120 ++++++++++++++++++++++++++++++++ deepspeed/comm/ccl.py | 6 +- deepspeed/comm/comm.py | 2 +- deepspeed/comm/torch.py | 20 +++++- op_builder/cpu/__init__.py | 2 +- op_builder/cpu/comm.py | 27 +++++++ 8 files changed, 174 insertions(+), 11 deletions(-) create mode 100644 csrc/cpu/comm/shm_interface.cpp diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 870d3e91816e..a0171723cfb8 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -300,12 +300,14 @@ def get_op_builder(self, class_name): # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path from op_builder import __deepspeed__ # noqa: F401 # type: ignore - from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder except ImportError: - from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder if class_name == "CCLCommBuilder": return CCLCommBuilder + elif class_name == "ShareMemCommBuilder": + return ShareMemCommBuilder elif class_name == "FusedAdamBuilder": return FusedAdamBuilder elif class_name == "CPUAdamBuilder": diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 786906717f23..d28509e59266 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -247,7 +247,7 @@ void all_reduce_caching(torch::Tensor& data, .wait()); } -void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op) +void inference_all_reduce(torch::Tensor& data, py::object op) { #ifdef DO_PROFILE static double total_time = 0.0; diff --git a/csrc/cpu/comm/shm_interface.cpp b/csrc/cpu/comm/shm_interface.cpp new file mode 100644 index 000000000000..981ea36515fc --- /dev/null +++ b/csrc/cpu/comm/shm_interface.cpp @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "shm.h" + +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif + +// Communication settings +static int world_rank = -1; +static int world_size = -1; + +static bool is_initialized = 0; + +static bool all_ranks_local_p = false; + +void initialize(int size, int rank) +{ + if (is_initialized) return; + + // Check whether all ranks is on the same physical machine. + // If true, we will use an SHM based low latency allreduce + + auto ls_string = std::getenv("LOCAL_SIZE"); + int ls = 0; + if (ls_string != NULL) { ls = std::stoi(std::getenv("LOCAL_SIZE")); } + + if (size >= 1 && size == ls) { all_ranks_local_p = true; } + + world_size = size; + world_rank = rank; + is_initialized = 1; + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { addr_string = ""; } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { port_string = ""; } + + if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); } +} + +int get_rank(int group = 0) { return world_rank; } + +int get_world_size(int group = 0) { return world_size; } + +// Success - return 0 +// Fail (cannot hornor the request and need to fall back) - return -1 +int inference_all_reduce(torch::Tensor& data, py::object op) +{ + if (!all_ranks_local_p) return -1; +#ifdef DO_PROFILE + static double total_time = 0.0; + static double total_time_sq = 0.0; + static int count = -16; // warmup + static double max_time = 0.0; + static double min_time = DBL_MAX; + // make sure all rank reach this point before measuring time + // turn on this if you suspect each rank didn't reach here at the same time (stragger) + // if (all_ranks_local_p) { barrier_wait(0, world_size); } + auto start = std::chrono::system_clock::now(); +#endif + + static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); + static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); + + assert(py::int_(op.attr("value")) == ReduceOpSum); + + auto numel = data.numel(); + + int data_size = 0; + bool data_type_fallback = false; + + switch (data.scalar_type()) { + case c10::ScalarType::BFloat16: data_size = numel * 2; break; + case c10::ScalarType::Float: data_size = numel * 4; break; + default: data_type_fallback = true; + } + + if (data_type_fallback) return -1; + + all_reduce_outer_loop(data, numel, data_size); + +#ifdef DO_PROFILE + auto end = std::chrono::system_clock::now(); + count++; + if (count > 0) { + double elapsed = std::chrono::duration_cast(end - start).count(); + if (elapsed > max_time) { max_time = elapsed; } + if (elapsed < min_time) { min_time = elapsed; } + total_time += elapsed; + total_time_sq += elapsed * elapsed; + if (world_rank == 0 && count == 1000) { + auto avg = total_time / count; + auto sd = + sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100; + printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n", + min_time, + max_time, + total_time / count, + sd); + } + } +#endif + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("initialize", &initialize, "shm initialize"); + m.def("get_rank", &get_rank, "get rank"); + m.def("get_world_size", &get_world_size, "get world size"); + m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation"); +} diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py index a199843d5086..f2a6cb6a36dc 100644 --- a/deepspeed/comm/ccl.py +++ b/deepspeed/comm/ccl.py @@ -98,12 +98,12 @@ def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False): else: return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op) - def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False): + def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None): name = "inference_all_reduce" if name in self.available_coll: - return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op) + return self.ccl_comm_op.inference_all_reduce(tensor, op) else: - return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op) + return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=False) def broadcast(self, tensor, src, group=None, async_op=False): return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op) diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index c92ae74375f1..4f826edab3d6 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -505,7 +505,7 @@ def inference_all_reduce(tensor, log_name='all_reduce', debug=get_caller_func()): global cdb - return cdb.inference_all_reduce(tensor, op, group, async_op) + return cdb.inference_all_reduce(tensor, op, group) @timed_op diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 35c31f9fe2ec..14d4f3847315 100644 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import deepspeed from deepspeed import utils from .utils import * @@ -19,6 +20,15 @@ DS_COMM_REDUCE_OFF = False +def build_shm_op(): + builder = get_accelerator().create_op_builder("ShareMemCommBuilder") + if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]: + return None + shm_cpp_module = builder.load() + print(f'DeepSpeed {builder.absolute_name()} built successfully') + return shm_cpp_module + + def has_coalescing_manager(): has_c10d = hasattr(torch.distributed, 'distributed_c10d') return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager') @@ -88,6 +98,7 @@ class TorchBackend(Backend): def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'): super(TorchBackend, self).__init__() + self.shm_comm_op = build_shm_op() self.has_all_reduce_coalesced = has_all_reduce_coalesced() self.has_coalescing_manager = has_coalescing_manager() self.all_gather_function = self.get_all_gather_function() @@ -99,6 +110,8 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name=' # it is not so we can run on a single GPU without doing any init_process_group self.single_gpu_mode = True self.init_process_group(backend, timeout, init_method, rank, world_size) + if self.shm_comm_op != None: + self.shm_comm_op.initialize(self.get_world_size(), self.get_rank()) @classmethod @compiler.disable @@ -139,9 +152,10 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) @compiler.disable - def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): - op = self._reduce_op(op) - return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) + def inference_all_reduce(self, tensor, op, group=None): + if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1: + op = self._reduce_op(op) + return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False) @compiler.disable def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py index 0ecf2caf65ae..30238add3f90 100644 --- a/op_builder/cpu/__init__.py +++ b/op_builder/cpu/__init__.py @@ -4,7 +4,7 @@ # DeepSpeed Team '''Copyright The Microsoft DeepSpeed Team''' -from .comm import CCLCommBuilder +from .comm import CCLCommBuilder, ShareMemCommBuilder from .fused_adam import FusedAdamBuilder from .cpu_adam import CPUAdamBuilder from .no_impl import NotImplementedBuilder diff --git a/op_builder/cpu/comm.py b/op_builder/cpu/comm.py index 38e965530f43..69b9226fd806 100644 --- a/op_builder/cpu/comm.py +++ b/op_builder/cpu/comm.py @@ -42,3 +42,30 @@ def extra_ldflags(self): return [] else: return ['-lccl', f'-L{ccl_root_path}/lib'] + + +class ShareMemCommBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_SHM_COMM" + NAME = "deepspeed_shm_comm" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/comm/shm_interface.cpp', 'csrc/cpu/comm/shm.cpp'] + + def include_paths(self): + includes = ['csrc/cpu/includes'] + return includes + + def cxx_args(self): + return ['-O2', '-fopenmp'] + + def is_compatible(self, verbose=True): + # TODO: add soft compatibility check for private binary release. + # a soft check, as in we know it can be trivially changed. + return super().is_compatible(verbose)