From db5a875b8de35e15293d572b9ea3527f1b73ede0 Mon Sep 17 00:00:00 2001 From: Avinash Maurya Date: Mon, 15 Jul 2024 12:56:27 -0400 Subject: [PATCH 1/8] Fix memory leak for pipelined optimizer swapper (#5700) We identified a memory leak when training with NVMe offloaded optimizer states. The issue occurs when `pipeline_write=true`, where the tensors that have swapped out and written to NVMe are not deallocated, leading to a memory leak. This PR resolves the issue by deallocating the unused tensors which have swapped out to NVMe. Co-authored-by: amaurya --- deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py index 03dfe663fcb6..66a372877d38 100644 --- a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py +++ b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py @@ -8,6 +8,7 @@ from deepspeed.ops.op_builder import AsyncIOBuilder from deepspeed import comm as dist +import torch from deepspeed.runtime.swap_tensor.constants import * from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object @@ -154,6 +155,8 @@ def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors): def _complete_swap_out(self, swap_out_type): self.swap_ops[swap_out_type].wait() + for buffer in self.swap_ops[swap_out_type].state_buffers: + buffer = torch.Tensor() self.swap_buffer_manager.free(self.swap_ops[swap_out_type].allocated_buffers) self.swap_ops[swap_out_type] = None From 0af9ac314fb3847b33e5b5f63a570ce6e1bdbca4 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 16 Jul 2024 00:57:14 +0800 Subject: [PATCH 2/8] Remove duplicated variable (#5727) Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/stage3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index edeedb691961..37b81d42c0d6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -338,7 +338,6 @@ def __init__( self.params_in_ipg_bucket = [] self.params_already_reduced = {} - self.is_gradient_accumulation_boundary = True self._release_ipg_buffers() self.previous_reduced_grads = None From a07a3c5d228bd4ee43b4c849f22b6753e94f6559 Mon Sep 17 00:00:00 2001 From: Yejing-Lai Date: Tue, 16 Jul 2024 00:57:45 +0800 Subject: [PATCH 3/8] Fix phi3 mini 128k load error (#5765) Fix phi3 mini 128k load error. 
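For context on why the missing name mattered: auto-TP decides whether a module's weights can be loaded directly by checking its class name against this allow-list, and the phi3 mini 128k checkpoint instantiates a `Phi3LongRoPEScaledRotaryEmbedding` module. Below is a minimal, standalone sketch of that name-based gate; the list here is abbreviated and illustrative, the real one lives in `deepspeed/module_inject/auto_tp.py`.

```python
import torch.nn as nn

# Abbreviated, illustrative allow-list; the full list lives in auto_tp.py.
LOAD_LAYERS = (nn.Linear, nn.Embedding, nn.LayerNorm)
LOAD_LAYER_NAMES = {"Phi3RMSNorm", "Phi3RotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding"}


def is_load_module(module: nn.Module) -> bool:
    # A module is loaded as-is when its type or class name is recognized;
    # before this patch the 128k rotary-embedding class name was missing.
    return isinstance(module, LOAD_LAYERS) or module._get_name() in LOAD_LAYER_NAMES


print(is_load_module(nn.LayerNorm(8)))  # True -> loaded directly
```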
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/module_inject/auto_tp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 1c76cbc0a6ef..6f4c170d8295 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -134,7 +134,7 @@ def is_load_module(module):
         load_layer_names = [
             "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
             "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
-            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
+            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding"
         ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names

From ec6cbb3c08ba58b42b99de3e94c5acbf5bfa757a Mon Sep 17 00:00:00 2001
From: "Ma, Guokai"
Date: Tue, 16 Jul 2024 06:24:11 +0800
Subject: [PATCH 4/8] [CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph (#5604)

This PR allows `deepspeed.comm.inference_all_reduce()` to enter the torch.compile
graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel as a
pybind function so that it could be called from Python code. However, a pybind
function cannot be recognized by PyTorch, so the graph breaks when
`inference_all_reduce` is called.

We address the issue by registering `inference_all_reduce` as a PyTorch custom op,
`torch.ops.deepspeed.inference_all_reduce`, so it can be built into the PyTorch graph.

The output trace code from torchinductor:
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute); primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce); primals_3 = None
        return [addmm, inference_all_reduce]
```

Note that in this PR the CPU inference_all_reduce op does not handle multi-node runs
or the FP16 data type. For FP16 support we will align with the PyTorch CPU FP16 plan.
For multi-node, we are still looking at the possibility of upstreaming the oneCCL
integration into PyTorch, so that we can make use of oneCCL for multi-node
tensor-parallel inference with PyTorch.

This PR is independent of https://github.com/microsoft/DeepSpeed/pull/5571. They can
work separately or together without issue.
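As a rough illustration of the user-visible effect, a tensor-parallel-style forward that calls the collective can now compile without a graph break. This is only a sketch: it assumes the CPU op is built, `deepspeed.comm` has been initialized, and the usual `inference_all_reduce(tensor, op)` entry point; `tp_row_linear` is an illustrative name.

```python
import torch
import deepspeed.comm as dist


def tp_row_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    y = x @ weight
    # With this PR the call lowers to torch.ops.deepspeed.inference_all_reduce_
    # instead of breaking the graph at a pybind boundary.
    dist.inference_all_reduce(y, op=dist.ReduceOp.SUM)
    return y


compiled_tp_row_linear = torch.compile(tp_row_linear)
```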
--------- Co-authored-by: Olatunji Ruwase Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- csrc/cpu/comm/shm_interface.cpp | 99 +++++++++++++++++++++++++++------ deepspeed/comm/torch.py | 5 +- 2 files changed, 85 insertions(+), 19 deletions(-) diff --git a/csrc/cpu/comm/shm_interface.cpp b/csrc/cpu/comm/shm_interface.cpp index d11c8cfa7375..5be5cb799a7b 100644 --- a/csrc/cpu/comm/shm_interface.cpp +++ b/csrc/cpu/comm/shm_interface.cpp @@ -46,15 +46,13 @@ void initialize(int size, int rank) if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); } } -int get_rank(int group = 0) { return world_rank; } - -int get_world_size(int group = 0) { return world_size; } +void inference_all_reduce_(torch::Tensor& data, int op); // Success - return 0 // Fail (cannot hornor the request and need to fall back) - return -1 -int inference_all_reduce(torch::Tensor& data, py::object op) +void inference_all_reduce_(torch::Tensor& data, int op) { - if (!all_ranks_local_p) return -1; + assert(op == 0); #ifdef DO_PROFILE static double total_time = 0.0; static double total_time_sq = 0.0; @@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op) auto start = std::chrono::system_clock::now(); #endif - static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); - static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); - - assert(py::int_(op.attr("value")) == ReduceOpSum); - auto numel = data.numel(); int data_size = 0; @@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op) default: data_type_fallback = true; } - if (data_type_fallback) return -1; + if (data_type_fallback) return; all_reduce_outer_loop(data, numel, data_size); @@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op) } } #endif - return 0; + return; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); } + +TORCH_LIBRARY(deepspeed, m) +{ + m.def("inference_all_reduce(Tensor self) -> Tensor"); + m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)"); +} + +torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_) +{ + torch::Tensor result_ = torch::empty_like(self_); + return result_; +} + +torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; } + +torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_) +{ + TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU); + torch::Tensor self_tensor = self_.contiguous(); + inference_all_reduce_(self_tensor, 0); + return self_; +} + +torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_) +{ + torch::Tensor result = self_.clone(); + inference_all_reduce__cpu(result); + return result; +} + +#include +// The boilerplate functionalization logic, that teaches functionalization +// how to map x_() calls into x() calls. +// Long term, we'd like to not require users to write this logic. +// HOWEVER, if you have a custom op that is mutable, +// You will still need to write an out-of-place version of that op! 
+at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x) +{ + // We expect all tensor inputs to our op to be "functional tensors" + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x)); + // First, sync and unwrap and functional tensors + at::functionalization::impl::sync(x); + auto x_ = at::functionalization::impl::from_functional_tensor(x); + // Grab the dispatcher entry corresponding to the out-of-place op, "x" + static auto op_handle = c10::Dispatcher::singleton() + // specify namespace::op_name, op_overload_name + .findSchemaOrThrow("deepspeed::inference_all_reduce", "") + // Specify the C++ schema of the out-of-place op. + .typed(); + // Next, redispatch to the out-of-place op, x() (user called x_, we call x) + at::Tensor tmp_output; + { + at::AutoDispatchSkipFunctionalize guard; + tmp_output = op_handle.call(x_); + } + // Finally, tell functionalization about this mutation. + at::functionalization::impl::replace_(x, tmp_output); + at::functionalization::impl::commit_update(x); + at::functionalization::impl::sync(x); + return x; +} + +TORCH_LIBRARY_IMPL(deepspeed, CPU, m) +{ + m.impl("inference_all_reduce", inference_all_reduce_cpu); + m.impl("inference_all_reduce_", inference_all_reduce__cpu); +} + +TORCH_LIBRARY_IMPL(deepspeed, Meta, m) +{ + m.impl("inference_all_reduce", inference_all_reduce_meta); + m.impl("inference_all_reduce_", inference_all_reduce__meta); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m) { - m.def("initialize", &initialize, "shm initialize"); - m.def("get_rank", &get_rank, "get rank"); - m.def("get_world_size", &get_world_size, "get world size"); - m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation"); + m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue); } diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 14d4f3847315..83754e98f033 100644 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -151,11 +151,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) - @compiler.disable def inference_all_reduce(self, tensor, op, group=None): - if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1: + if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False) + else: + return torch.ops.deepspeed.inference_all_reduce_(tensor) @compiler.disable def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): From 61e07786d5ee3926cd81853a97087f829172ad2c Mon Sep 17 00:00:00 2001 From: Max Kovalenko <75629718+deepcharm@users.noreply.github.com> Date: Tue, 16 Jul 2024 03:00:25 +0300 Subject: [PATCH 5/8] Added wrappers for hpu tensors based on dtype (#5771) This avoids graph breaks when using torch.compile. 
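The replacement pattern is simply a `functools.partial` over the plain `torch.tensor` factory, which TorchDynamo can trace, instead of the legacy per-device tensor-type constructors. A small CPU-only sketch of the same shape (the accelerator itself passes `device='hpu'`):

```python
import functools

import torch

# Same wrapper shape as the accelerator properties, but targeting CPU so it runs anywhere.
BFloat16Tensor = functools.partial(torch.tensor, dtype=torch.bfloat16, device='cpu')
LongTensor = functools.partial(torch.tensor, dtype=torch.long, device='cpu')

x = BFloat16Tensor([1.0, 2.0, 3.0])  # same as torch.tensor(..., dtype=torch.bfloat16, device='cpu')
idx = LongTensor([0, 2])
print(x.dtype, idx.dtype)  # torch.bfloat16 torch.int64
```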
---
 accelerator/hpu_accelerator.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index dd87461696cf..5c2e92c9ef69 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+import functools
 import os
 import pkgutil
 import importlib
@@ -196,31 +197,31 @@ def replay_graph(self, graph):
     # Tensor operations
     @property
     def BFloat16Tensor(self):
-        return self.hpu.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')
 
     @property
     def ByteTensor(self):
-        return self.hpu.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')
 
     @property
     def DoubleTensor(self):
-        return self.hpu.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')
 
     @property
     def FloatTensor(self):
-        return self.hpu.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')
 
     @property
     def HalfTensor(self):
-        return self.hpu.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')
 
     @property
     def IntTensor(self):
-        return self.hpu.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')
 
     @property
     def LongTensor(self):
-        return self.hpu.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')
 
     def pin_memory(self, tensor, align_bytes=1):
         return tensor.pin_memory(self.device())

From 98272d14fe0be043c47b5637cc560c6c49ea88ce Mon Sep 17 00:00:00 2001
From: billishyahao
Date: Tue, 16 Jul 2024 08:54:02 +0800
Subject: [PATCH 6/8] [bugfix] promote state in bf16_optimizer (#5767)

This patch promotes `state` in `bf16_optimizer` so that it is accessible in
downstream DeepSpeed use cases. For example, without the patch, we hit the
following issue in the Megatron-DeepSpeed LLaMA showcase:

```
[rank3]: Traceback (most recent call last):
[rank3]:   File "/yahao/Megatron-DeepSpeed/pretrain_gpt.py", line 356, in <module>
[rank3]:     pretrain(train_valid_test_datasets_provider,
[rank3]:   File "/yahao/Megatron-DeepSpeed/megatron/training.py", line 222, in pretrain
[rank3]:     iteration = train(forward_step_func,
[rank3]:   File "/yahao/Megatron-DeepSpeed/megatron/training.py", line 1264, in train
[rank3]:     report_memory_flag = training_log(loss_dict, total_loss_dict,
[rank3]:   File "/yahao/Megatron-DeepSpeed/megatron/training.py", line 999, in training_log
[rank3]:     opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item())**2
[rank3]: AttributeError: 'BF16_Optimizer' object has no attribute 'state'
```

With the patch, the invocation passes smoothly.
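With the forwarding property in place, wrapper-level code can read per-parameter optimizer state exactly as in the Megatron-style logging above. A hedged sketch of that access pattern follows; `adam_sq_norm` is an illustrative helper, and it assumes an Adam-like inner optimizer that has already taken a step so `exp_avg_sq` is populated.

```python
import torch


def adam_sq_norm(optimizer) -> float:
    """Sum of squared exp_avg_sq norms, as used in the Megatron training log."""
    total = 0.0
    for group in optimizer.param_groups:          # forwarded to the wrapped optimizer
        for param in group['params']:
            state = optimizer.state[param]        # now also forwarded by BF16_Optimizer
            if 'exp_avg_sq' in state:
                total += torch.norm(state['exp_avg_sq']).item() ** 2
    return total
```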
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/bf16_optimizer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 965b446163ec..325188f02931 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -540,6 +540,11 @@ def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups + @property + def state(self): + """Forward the wrapped optimizer's states.""" + return self.optimizer.state + def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): assert self.immediate_grad_update self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True) From acbaca3223adc38cbacd977b887be6baf34a197f Mon Sep 17 00:00:00 2001 From: Dogacan Colak Date: Tue, 16 Jul 2024 09:31:20 -0400 Subject: [PATCH 7/8] Launcher mode with SSH bypass (#5728) https://github.com/microsoft/DeepSpeed/issues/5510 Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/launcher/runner.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 0ca89dd0497a..12dd629b1b34 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -117,6 +117,12 @@ def parse_args(args=None): help="(optional) IP address of node 0, will be " "inferred via 'hostname -I' if not specified.") + parser.add_argument("--node_rank", + default=-1, + type=int, + help="ID of each node in the range [0:N). " + "Only required when --no_ssh is set.") + parser.add_argument("--launcher", default=PDSH_LAUNCHER, type=str, @@ -145,6 +151,10 @@ def parse_args(args=None): help="Do not pass local_rank as an argument when calling " "the user's training script.") + parser.add_argument("--no_ssh", + action="store_true", + help="Launch training independently on each node without ssh setup.") + parser.add_argument("--no_ssh_check", action="store_true", help="Do not perform ssh check in multi-node launcher model") @@ -427,7 +437,7 @@ def main(args=None): env = os.environ.copy() # validate that passwordless-ssh is workly properly with this hostfile - if multi_node_exec and not args.no_ssh_check: + if multi_node_exec and not args.no_ssh_check and not args.no_ssh: first_host = list(active_resources.keys())[0] try: ssh_check_cmd = "ssh -o PasswordAuthentication=no " @@ -483,16 +493,22 @@ def main(args=None): if args.elastic_training: assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training" + if args.no_ssh: + assert (0 <= args.node_rank < + len(active_resources)), "Launching training without ssh, but --node_rank is not set correctly." 
+ # encode world info as base64 to make it easier to pass via command line world_info_base64 = encode_world_info(active_resources) - multi_node_exec = args.force_multi or len(active_resources) > 1 + multi_node_exec = (args.force_multi or len(active_resources) > 1) and not args.no_ssh if not multi_node_exec: deepspeed_launch = [ sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}", f"--master_addr={args.master_addr}", f"--master_port={args.master_port}" ] + if args.no_ssh: + deepspeed_launch.append(f"--node_rank={args.node_rank}") if args.no_python: deepspeed_launch.append("--no_python") if args.module: From 78c6c449c967e0c0f4755fc8256376b177ef84ef Mon Sep 17 00:00:00 2001 From: beep-bebop <41529995+beep-bebop@users.noreply.github.com> Date: Tue, 16 Jul 2024 21:32:16 +0800 Subject: [PATCH 8/8] Update the list of supported models in the Chinese README of fastgen (#5773) Updates to the three models supported in deepspeed-fastgen since the last Chinese README update. Co-authored-by: weifangyuan --- blogs/deepspeed-fastgen/chinese/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/blogs/deepspeed-fastgen/chinese/README.md b/blogs/deepspeed-fastgen/chinese/README.md index fb9cc7319ab6..1e92e4169450 100644 --- a/blogs/deepspeed-fastgen/chinese/README.md +++ b/blogs/deepspeed-fastgen/chinese/README.md @@ -226,6 +226,10 @@ DeepSpeed-FastGen 是 [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII * [LLaMA](https://huggingface.co/models?other=llama) 和 [LLaMA-2](https://huggingface.co/models?other=llama-2) * [Mistral](https://huggingface.co/models?other=mistral) * [OPT](https://huggingface.co/models?other=opt) +* [Falcon](https://huggingface.co/models?other=falcon) +* [Mixtral](https://huggingface.co/models?other=mixtral) +* [Phi-2](https://huggingface.co/models?other=phi-msft) +* [Qwen](https://huggingface.co/models?other=qwen) 所有当前模型都利用了后端的 [HuggingFace](https://github.com/huggingface) API 来提供模型权重和模型对应的分词器。