diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index dd87461696cf..5c2e92c9ef69 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+import functools
 import os
 import pkgutil
 import importlib
@@ -196,31 +197,31 @@ def replay_graph(self, graph):
     # Tensor operations
     @property
     def BFloat16Tensor(self):
-        return self.hpu.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')
 
     @property
     def ByteTensor(self):
-        return self.hpu.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')
 
     @property
     def DoubleTensor(self):
-        return self.hpu.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')
 
     @property
     def FloatTensor(self):
-        return self.hpu.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')
 
     @property
     def HalfTensor(self):
-        return self.hpu.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')
 
     @property
     def IntTensor(self):
-        return self.hpu.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')
 
     @property
     def LongTensor(self):
-        return self.hpu.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')
 
     def pin_memory(self, tensor, align_bytes=1):
         return tensor.pin_memory(self.device())
diff --git a/blogs/deepspeed-fastgen/chinese/README.md b/blogs/deepspeed-fastgen/chinese/README.md
index fb9cc7319ab6..1e92e4169450 100644
--- a/blogs/deepspeed-fastgen/chinese/README.md
+++ b/blogs/deepspeed-fastgen/chinese/README.md
@@ -226,6 +226,10 @@ DeepSpeed-FastGen 是 [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII
 * [LLaMA](https://huggingface.co/models?other=llama) 和 [LLaMA-2](https://huggingface.co/models?other=llama-2)
 * [Mistral](https://huggingface.co/models?other=mistral)
 * [OPT](https://huggingface.co/models?other=opt)
+* [Falcon](https://huggingface.co/models?other=falcon)
+* [Mixtral](https://huggingface.co/models?other=mixtral)
+* [Phi-2](https://huggingface.co/models?other=phi-msft)
+* [Qwen](https://huggingface.co/models?other=qwen)
 
 所有当前模型都利用了后端的 [HuggingFace](https://github.com/huggingface) API 来提供模型权重和模型对应的分词器。
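The accelerator/hpu_accelerator.py hunks above replace the `self.hpu.*Tensor` type aliases with `functools.partial` constructors. A minimal sketch of why the partial behaves like a tensor constructor, shown on the CPU device so it runs without Habana hardware (the accelerator binds device='hpu' instead):

    import functools

    import torch

    # Bind dtype/device once; the data argument is forwarded positionally.
    FloatTensor = functools.partial(torch.tensor, dtype=torch.float, device='cpu')

    x = FloatTensor([1.0, 2.0, 3.0])
    assert x.dtype == torch.float32 and x.device.type == 'cpu'
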
diff --git a/csrc/cpu/comm/shm_interface.cpp b/csrc/cpu/comm/shm_interface.cpp
index d11c8cfa7375..5be5cb799a7b 100644
--- a/csrc/cpu/comm/shm_interface.cpp
+++ b/csrc/cpu/comm/shm_interface.cpp
@@ -46,15 +46,13 @@ void initialize(int size, int rank)
 
     if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
 }
 
-int get_rank(int group = 0) { return world_rank; }
-
-int get_world_size(int group = 0) { return world_size; }
+void inference_all_reduce_(torch::Tensor& data, int op);
 
 // Success - return 0
 // Fail (cannot hornor the request and need to fall back) - return -1
-int inference_all_reduce(torch::Tensor& data, py::object op)
+void inference_all_reduce_(torch::Tensor& data, int op)
 {
-    if (!all_ranks_local_p) return -1;
+    assert(op == 0);
 #ifdef DO_PROFILE
     static double total_time = 0.0;
     static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
     auto start = std::chrono::system_clock::now();
 #endif
 
-    static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
-    static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
-
-    assert(py::int_(op.attr("value")) == ReduceOpSum);
-
     auto numel = data.numel();
 
     int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
         default: data_type_fallback = true;
     }
 
-    if (data_type_fallback) return -1;
+    if (data_type_fallback) return;
 
     all_reduce_outer_loop(data, numel, data_size);
 
@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
         }
     }
 #endif
-    return 0;
+    return;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }
+
+TORCH_LIBRARY(deepspeed, m)
+{
+    m.def("inference_all_reduce(Tensor self) -> Tensor");
+    m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
+}
+
+torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
+{
+    torch::Tensor result_ = torch::empty_like(self_);
+    return result_;
+}
+
+torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }
+
+torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
+{
+    TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
+    torch::Tensor self_tensor = self_.contiguous();
+    inference_all_reduce_(self_tensor, 0);
+    return self_;
+}
+
+torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
+{
+    torch::Tensor result = self_.clone();
+    inference_all_reduce__cpu(result);
+    return result;
+}
+
+#include <ATen/FunctionalTensorWrapper.h>
+// The boilerplate functionalization logic, that teaches functionalization
+// how to map x_() calls into x() calls.
+// Long term, we'd like to not require users to write this logic.
+// HOWEVER, if you have a custom op that is mutable,
+// you will still need to write an out-of-place version of that op!
+at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
+{
+    // We expect all tensor inputs to our op to be "functional tensors"
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
+    // First, sync and unwrap any functional tensors
+    at::functionalization::impl::sync(x);
+    auto x_ = at::functionalization::impl::from_functional_tensor(x);
+    // Grab the dispatcher entry corresponding to the out-of-place op, "x"
+    static auto op_handle = c10::Dispatcher::singleton()
+                                // specify namespace::op_name, op_overload_name
+                                .findSchemaOrThrow("deepspeed::inference_all_reduce", "")
+                                // Specify the C++ schema of the out-of-place op.
+                                .typed<at::Tensor(const at::Tensor&)>();
+    // Next, redispatch to the out-of-place op, x() (user called x_, we call x)
+    at::Tensor tmp_output;
+    {
+        at::AutoDispatchSkipFunctionalize guard;
+        tmp_output = op_handle.call(x_);
+    }
+    // Finally, tell functionalization about this mutation.
+    at::functionalization::impl::replace_(x, tmp_output);
+    at::functionalization::impl::commit_update(x);
+    at::functionalization::impl::sync(x);
+    return x;
+}
+
+TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
+{
+    m.impl("inference_all_reduce", inference_all_reduce_cpu);
+    m.impl("inference_all_reduce_", inference_all_reduce__cpu);
+}
+
+TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
+{
+    m.impl("inference_all_reduce", inference_all_reduce_meta);
+    m.impl("inference_all_reduce_", inference_all_reduce__meta);
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
 {
-    m.def("initialize", &initialize, "shm initialize");
-    m.def("get_rank", &get_rank, "get rank");
-    m.def("get_world_size", &get_world_size, "get world size");
-    m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
+    m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
 }
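With the kernels registered via TORCH_LIBRARY, the SHM all-reduce is reachable from Python as an ATen-style custom op. A hedged usage sketch (assuming the extension above has been built and loaded and torch.distributed is initialized), mirroring the fallback logic in deepspeed/comm/torch.py below:

    import torch
    import torch.distributed as dist

    def shm_or_fallback_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
        # TORCH_LIBRARY exposes the kernels under torch.ops.deepspeed.*; the Meta
        # and Functionalize registrations let torch.compile trace through them.
        if hasattr(torch.ops, 'deepspeed') and hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
            return torch.ops.deepspeed.inference_all_reduce_(tensor)  # in-place SHM path
        dist.all_reduce(tensor)  # generic fallback when the custom op is unavailable
        return tensor
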
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 14d4f3847315..83754e98f033 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -151,11 +151,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
             op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
 
-    @compiler.disable
     def inference_all_reduce(self, tensor, op, group=None):
-        if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
+        if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
             op = self._reduce_op(op)
             return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
+        else:
+            return torch.ops.deepspeed.inference_all_reduce_(tensor)
 
     @compiler.disable
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py
index 0ca89dd0497a..12dd629b1b34 100755
--- a/deepspeed/launcher/runner.py
+++ b/deepspeed/launcher/runner.py
@@ -117,6 +117,12 @@ def parse_args(args=None):
                         help="(optional) IP address of node 0, will be "
                         "inferred via 'hostname -I' if not specified.")
 
+    parser.add_argument("--node_rank",
+                        default=-1,
+                        type=int,
+                        help="ID of each node in the range [0:N). "
+                        "Only required when --no_ssh is set.")
+
     parser.add_argument("--launcher",
                         default=PDSH_LAUNCHER,
                         type=str,
@@ -145,6 +151,10 @@ def parse_args(args=None):
                         help="Do not pass local_rank as an argument when calling "
                         "the user's training script.")
 
+    parser.add_argument("--no_ssh",
+                        action="store_true",
+                        help="Launch training independently on each node without ssh setup.")
+
     parser.add_argument("--no_ssh_check",
                         action="store_true",
                         help="Do not perform ssh check in multi-node launcher model")
@@ -427,7 +437,7 @@ def main(args=None):
     env = os.environ.copy()
 
     # validate that passwordless-ssh is workly properly with this hostfile
-    if multi_node_exec and not args.no_ssh_check:
+    if multi_node_exec and not args.no_ssh_check and not args.no_ssh:
         first_host = list(active_resources.keys())[0]
         try:
             ssh_check_cmd = "ssh -o PasswordAuthentication=no "
@@ -483,16 +493,22 @@ def main(args=None):
     if args.elastic_training:
         assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training"
 
+    if args.no_ssh:
+        assert (0 <= args.node_rank <
+                len(active_resources)), "Launching training without ssh, but --node_rank is not set correctly."
+
     # encode world info as base64 to make it easier to pass via command line
     world_info_base64 = encode_world_info(active_resources)
 
-    multi_node_exec = args.force_multi or len(active_resources) > 1
+    multi_node_exec = (args.force_multi or len(active_resources) > 1) and not args.no_ssh
 
     if not multi_node_exec:
         deepspeed_launch = [
             sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}",
             f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
         ]
+        if args.no_ssh:
+            deepspeed_launch.append(f"--node_rank={args.node_rank}")
        if args.no_python:
             deepspeed_launch.append("--no_python")
         if args.module:
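On how the new launcher flags compose: with --no_ssh the runner no longer fans out over ssh; the same command is run on every node and --node_rank tells each copy which entry of the hostfile it owns. A hedged sketch with hypothetical addresses and script name:

    import shlex

    def build_no_ssh_cmd(node_rank: int, num_nodes: int = 2) -> str:
        # Mirrors the validation in runner.py: node_rank must lie in [0, num_nodes).
        assert 0 <= node_rank < num_nodes
        args = [
            "deepspeed", "--hostfile=hostfile", "--no_ssh", f"--node_rank={node_rank}",
            "--master_addr=10.0.0.1", "--master_port=29500", "train.py"
        ]
        return shlex.join(args)

    if __name__ == "__main__":
        for rank in range(2):  # run the printed command on the matching node
            print(f"node {rank}: {build_no_ssh_cmd(rank)}")
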
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 1c76cbc0a6ef..6f4c170d8295 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -134,7 +134,7 @@ def is_load_module(module):
         load_layer_names = [
             "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
             "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
-            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
+            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding"
         ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 965b446163ec..325188f02931 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -540,6 +540,11 @@ def param_groups(self):
         """Forward the wrapped optimizer's parameters."""
         return self.optimizer.param_groups
 
+    @property
+    def state(self):
+        """Forward the wrapped optimizer's states."""
+        return self.optimizer.state
+
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update
         self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)
diff --git a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
index 03dfe663fcb6..66a372877d38 100644
--- a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
+++ b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
@@ -8,6 +8,7 @@
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from deepspeed import comm as dist
+import torch
 
 from deepspeed.runtime.swap_tensor.constants import *
 from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object
@@ -154,6 +155,8 @@ def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors):
 
     def _complete_swap_out(self, swap_out_type):
         self.swap_ops[swap_out_type].wait()
+        for buffer in self.swap_ops[swap_out_type].state_buffers:
+            buffer = torch.Tensor()
         self.swap_buffer_manager.free(self.swap_ops[swap_out_type].allocated_buffers)
         self.swap_ops[swap_out_type] = None
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index edeedb691961..37b81d42c0d6 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -338,7 +338,6 @@ def __init__(
         self.params_in_ipg_bucket = []
 
         self.params_already_reduced = {}
-        self.is_gradient_accumulation_boundary = True
         self._release_ipg_buffers()
         self.previous_reduced_grads = None
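The bf16_optimizer.py hunk above forwards `state` through the wrapper so callers written against the torch.optim.Optimizer interface keep working. A minimal sketch of that forwarding pattern, using a hypothetical OptimizerWrapper stand-in rather than the actual BF16_Optimizer:

    import torch

    class OptimizerWrapper:
        """Hypothetical stand-in that forwards attributes the way BF16_Optimizer does."""

        def __init__(self, optimizer: torch.optim.Optimizer):
            self.optimizer = optimizer

        @property
        def param_groups(self):
            """Forward the wrapped optimizer's parameters."""
            return self.optimizer.param_groups

        @property
        def state(self):
            """Forward the wrapped optimizer's states."""
            return self.optimizer.state

    p = torch.nn.Parameter(torch.zeros(4))
    wrapped = OptimizerWrapper(torch.optim.Adam([p], lr=1e-3))
    p.grad = torch.ones_like(p)
    wrapped.optimizer.step()
    assert len(wrapped.state) == 1  # Adam's per-parameter state is visible through the wrapper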