From da771ed42e41a44d5047813ca4672f1cfe9d1731 Mon Sep 17 00:00:00 2001
From: Yejing-Lai
Date: Tue, 17 Dec 2024 06:14:53 +0800
Subject: [PATCH 01/13] Add MLP/lm_head tp grain size setting. (#6828)

This PR adds an MLP/lm_head tensor-parallel granularity setting to the
deepspeed.init_inference() API, making the MLP/lm_head sharding grain size
configurable. DNN libraries favor tensor sizes with power-of-2 granularity,
so we pick 64 as the default. This is a preliminary solution; if there is a
better one, we can discuss it together. Thanks~
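For reference, the sharding arithmetic this setting controls can be sketched
standalone (a minimal illustration, not the patched DeepSpeed code itself):

```python
def shard_size(total_size, mp_size, rank, tp_grain_size=64):
    # Work in units of tp_grain_size so every shard stays a multiple of it.
    if total_size >= tp_grain_size:
        grains = total_size // tp_grain_size
        return (grains // mp_size + (1 if rank < (grains % mp_size) else 0)) * tp_grain_size
    # Fall back to plain splitting for tensors smaller than one grain.
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

# An 11008-wide MLP across 4 ranks: 11008 // 64 = 172 grains, 43 per rank,
# i.e. 2752 columns each -- every shard is a multiple of the grain size.
print([shard_size(11008, 4, r) for r in range(4)])  # [2752, 2752, 2752, 2752]
```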
---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase
---
 deepspeed/inference/config.py             |  3 +++
 deepspeed/module_inject/replace_module.py |  5 ++++-
 deepspeed/module_inject/tp_shard.py       | 11 ++++++++---
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index c7c7684fff79..42ffebbc4386 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """
 
+    tp_grain_size: int = 64
+    "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."
+
     mpu: object = None
     """
     A model parallelism unit object that implements
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 7afe6ca903fb..e59f84bc8453 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
 from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
 
 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         if hasattr(model_config, 'num_attention_heads'):
             set_num_attention_heads(getattr(model_config, 'num_attention_heads'))
 
+        # 4.4 set tp_grain_size
+        set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 57be0c793856..3e6fc2b63ef1 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
     n_embd = num
 
 
+def set_tp_grain_size(num):
+    global tp_grain_size
+    tp_grain_size = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
             my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
             return total_size * my_slices // num_kv_heads
     else:
-        if total_size >= 64:
-            grain_size = total_size // 64
-            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+        if total_size >= tp_grain_size:
+            grain_size = total_size // tp_grain_size
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
         else:
             return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

From a964e435532699908e5750abdb027ae583ff793d Mon Sep 17 00:00:00 2001
From: Aviv Keshet
Date: Tue, 17 Dec 2024 09:33:09 -0800
Subject: [PATCH 02/13] Fix --enable_each_rank_log when used with PDSH
 multi-node runner (#6863)

This PR fixes https://github.com/microsoft/DeepSpeed/issues/6859 by threading
this argument into the deepspeed launcher command built by PDSHRunner.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/launcher/multinode_runner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 74d20a6d53e5..fe2fa1b476be 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -104,6 +104,8 @@ def get_cmd(self, environment, active_resources):
             deepspeed_launch.append("--no_local_rank")
         if self.args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if self.args.enable_each_rank_log:
+            deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}")
         if self.args.elastic_training:
             deepspeed_launch.append("--enable_elastic_training")
             deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")

From 2f32966b1cd874aa4373177c8f8c4214ad57d020 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Tue, 17 Dec 2024 11:53:47 -0800
Subject: [PATCH 03/13] Update transformers ops unit tests to use
 `required_torch_version` (#6884)
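The helper the tests switch to can be approximated as follows (a simplified
sketch for illustration; the real implementation lives in
`deepspeed.utils.torch`):

```python
from packaging import version
import torch

def required_torch_version(min_version=None, max_version=None):
    # Simplified stand-in: check the installed torch version against
    # inclusive lower/upper bounds, ignoring any local build suffix.
    torch_version = version.parse(torch.__version__.split("+")[0])
    if min_version is not None and torch_version < version.parse(str(min_version)):
        return False
    if max_version is not None and torch_version > version.parse(str(max_version)):
        return False
    return True

# e.g. pick the gelu variant the way the updated test does:
approximate = 'tanh' if required_torch_version(min_version=1.12) else 'none'
```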
---
 .../ops/transformer/inference/test_bias_geglu.py   |  2 --
 .../ops/transformer/inference/test_bias_gelu.py    |  2 --
 .../ops/transformer/inference/test_bias_relu.py    |  2 --
 tests/unit/ops/transformer/inference/test_gelu.py  | 14 +++++---------
 .../unit/ops/transformer/inference/test_matmul.py  |  1 -
 .../unit/ops/transformer/inference/test_softmax.py |  2 --
 6 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py
index 05de4fbb4cf8..c995d2a8c46d 100644
--- a/tests/unit/ops/transformer/inference/test_bias_geglu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_geglu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py
index b69030e87ace..e3a3bad63961 100644
--- a/tests/unit/ops/transformer/inference/test_bias_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -16,8 +16,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_gelu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally and using the tanh approximation
diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py
index 57134665b241..69078f9f7646 100644
--- a/tests/unit/ops/transformer/inference/test_bias_relu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_relu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_relu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py
index 5f820ef3b579..a58abfdb100c 100644
--- a/tests/unit/ops/transformer/inference/test_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_gelu.py
@@ -9,12 +9,11 @@
 from deepspeed.ops.op_builder import InferenceBuilder
 from deepspeed.ops.transformer import DeepSpeedInferenceConfig
 from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
@@ -23,14 +22,11 @@ def allclose(x, y):
 
 
 def version_appropriate_gelu(activations):
-    global torch_minor_version
-    if torch_minor_version is None:
-        torch_minor_version = int(torch.__version__.split('.')[1])
-    # If torch version = 1.12
-    if torch_minor_version < 12:
-        return torch.nn.functional.gelu(activations)
-    else:
+    # gelu behavior changes (correctly) in torch 1.12
+    if required_torch_version(min_version=1.12):
         return torch.nn.functional.gelu(activations, approximate='tanh')
+    else:
+        return torch.nn.functional.gelu(activations)
 
 
 def run_gelu_reference(activations):
diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py
index 559aa2c60afe..2ab195ee0115 100644
--- a/tests/unit/ops/transformer/inference/test_matmul.py
+++ b/tests/unit/ops/transformer/inference/test_matmul.py
@@ -12,7 +12,6 @@
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
 inference_module = None
-torch_minor_version = None
 
 
 def allclose(x, y):
diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py
index e582be1b926a..83785ac38ebb 100644
--- a/tests/unit/ops/transformer/inference/test_softmax.py
+++ b/tests/unit/ops/transformer/inference/test_softmax.py
@@ -11,8 +11,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
From 4cd1d97460b677563d57f07a293724bdc02e0ef5 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Tue, 17 Dec 2024 17:30:52 -0800
Subject: [PATCH 04/13] Don't error out when cpu accelerator doesn't have
 torch (as default for whl building) (#6886)

This fixes a bug introduced in #6845, which breaks the `no-torch` workflow
that we require in order to do releases, where we do not require torch to be
in the environment when building an sdist. This adds the same logic to the
CPU accelerator that the CUDA accelerator already had, so that torch is not
required to be installed to build the whl.
---
 .github/workflows/no-torch.yml | 1 +
 accelerator/cpu_accelerator.py | 8 +++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml
index 1a13c0f3f4f1..5b89a6f36787 100644
--- a/.github/workflows/no-torch.yml
+++ b/.github/workflows/no-torch.yml
@@ -4,6 +4,7 @@ on:
   workflow_dispatch:
   pull_request:
     paths:
+      - 'accelerator/**'
       - '.github/workflows/no-torch.yml'
       - 'op_builder/**'
   schedule:
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 0e49bd9f6458..de711f73144e 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -3,9 +3,15 @@
 
 # DeepSpeed Team
 
-import torch
 from .abstract_accelerator import DeepSpeedAccelerator
 
+# During setup stage torch may not be installed, pass on no torch will
+# allow op builder related API to be executed.
+try:
+    import torch
+except ImportError as e:
+    pass
+
 try:
     import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
     oneccl_imported_p = True

From 0b25630abe8f7dd4e64c277ff92f5f7e36a27284 Mon Sep 17 00:00:00 2001
From: Daniel Huang
Date: Wed, 18 Dec 2024 08:09:31 -0800
Subject: [PATCH 05/13] Add arctic model support by adding w2 to all_reduce
 (#6856)

As the title says. The default behavior of the Arctic model produces shape
issues with AutoTP because the MLP layer performs `w2 * act(w1*w3)`. However,
the method used to fix Mixtral-7x8b in #5257 does not work here, since the
Arctic MLP is also used within a ModuleList for the MoE. This leaves the MLP
weights hiding behind individual experts as layers `#.w#`, which is not
caught by the fix in #5257. This PR adds the check directly within `_replace`,
where it can match actual layer names for the `w2` key in the model and patch
them with `all_reduce`.
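For background, the reason `w2` must become an all-reduce linear can be
sketched as follows (illustrative tensor-parallel pseudocode, not DeepSpeed
internals): `w1`/`w3` are column-sharded, so the gated activation stays
rank-local, while the row-sharded `w2` leaves each rank with a partial sum
that must be reduced:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def sharded_mlp_forward(x, w1_shard, w3_shard, w2_shard):
    # w1/w3 are split along their output dim: each rank owns a slice of the
    # hidden dimension, so the gated activation is computed locally.
    hidden = F.silu(x @ w1_shard.t()) * (x @ w3_shard.t())
    # w2 is split along its input dim: each rank produces a partial product...
    partial = hidden @ w2_shard.t()
    # ...and the true output is the sum of the partials across all ranks.
    dist.all_reduce(partial)
    return partial
```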
---------

Signed-off-by: Daniel Huang
Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/module_inject/auto_tp.py              | 6 +++++-
 docs/_tutorials/automatic-tensor-parallelism.md | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 221d490a37d2..5441000e581d 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
                 weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                          dist.get_world_size(), False)
                 return LinearAllreduce(weight, bias, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
+        arctic_w2_all_reduce_linear = False
+        if 'Arctic' in str(self.module) and 'w2' in name:
+            arctic_w2_all_reduce_linear = True
         # For MLP including chunk layer.
         if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
             weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
             return LinearLayer(weight=weight, bias=bias)
-        if name in self.all_reduce_linears:
+        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
             # else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md
index d5a08b27bf4d..6488f9b718fe 100755
--- a/docs/_tutorials/automatic-tensor-parallelism.md
+++ b/docs/_tutorials/automatic-tensor-parallelism.md
@@ -121,6 +121,7 @@ The following results were collected using V100 SXM2 32GB GPUs.
 The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet.
 
 - albert
+- arctic
 - baichuan
 - bert
 - bigbird_pegasus

From b344c04df0fdf058617004924a7aaa15055dccce Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 18 Dec 2024 11:49:28 -0500
Subject: [PATCH 06/13] Update code owners (#6890)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 CODEOWNERS | 45 ++++++++++++++++++++++++---------------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/CODEOWNERS b/CODEOWNERS
index c0fc85cb8b89..b0d3b8b0d77b 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -8,49 +8,52 @@
 
 # top-level repo folders
 /.github/ @loadams
-/azure/ @awan-10
-/benchmarks/ @awan-10 @tjruwase
+/azure/ @loadams
+/benchmarks/ @guanhuawang @tjruwase
 /bin/ @loadams
-/csrc/ @awan-10
+/csrc/ @tjruwase
 /deepspeed/ @loadams @tjruwase
-/docker/ @awan-10
+/docker/ @loadams @guanhuawang
 /docs/ @loadams @tjruwase
-/examples/ @awan-10 @tohtana
+/examples/ @jomayeri @tohtana
 /op_builder/ @loadams @tjruwase @jomayeri
-/release/ @loadams
+/release/ @loadams @jomayeri
 /requirements/ @loadams
-/scripts/ @awan-10
+/scripts/ @loadams @tjruwase
 /tests/ @tjruwase @loadams @tohtana
 
 # deepspeed
 /deepspeed/autotuning/ @loadams
 /deepspeed/checkpoint/ @tjruwase
-/deepspeed/comm/ @awan-10
+/deepspeed/comm/ @guanhuawang
 /deepspeed/compression/ @tjruwase
-/deepspeed/elasticity/ @awan-10
+/deepspeed/elasticity/ @tjruwase
 /deepspeed/launcher/ @loadams
-/deepspeed/module_inject/ @awan-10
+/deepspeed/module_inject/ @hwchen2017 @loadams
 /deepspeed/moe/ @tohtana
-/deepspeed/monitor/ @awan-10
+/deepspeed/monitor/ @tjruwase
 /deepspeed/nebula/ @tjruwase
+/deepspeed/nvme/ @tjruwase @jomayeri
 /deepspeed/ops/ @tohtana
 /deepspeed/pipe/ @tohtana @loadams
 /deepspeed/profiling/ @loadams
-/deepspeed/utils/ @tjruwase @awan-10
+/deepspeed/sequence/ @tohtana
+/deepspeed/utils/ @tjruwase @tohtana
 
 # inference
-/deepspeed/inference/ @awan-10
-/deepspeed/model_implementations/ @awan-10
+/deepspeed/inference/ @hwchen2017 @tohtana
+/deepspeed/model_implementations/ @tohtana @loadams
 
 # training
 /deepspeed/runtime/ @tjruwase @tohtana
 /deepspeed/runtime/activation_checkpointing/ @tjruwase
 /deepspeed/runtime/checkpoint_engine/ @tjruwase
-/deepspeed/runtime/comm/ @awan-10
-/deepspeed/runtime/compression/ @awan-10
+/deepspeed/runtime/comm/ @guanhuawang
+/deepspeed/runtime/compression/ @tjruwase
 /deepspeed/runtime/data_pipeline/ @tjruwase
-/deepspeed/runtime/fp16/ @tjruwase
-/deepspeed/runtime/fp16/onebit/ @awan-10
-/deepspeed/runtime/pipe/ @loadams
-/deepspeed/runtime/swap_tensor/ @tjruwase
-/deepspeed/runtime/zero/ @tjruwase
+/deepspeed/runtime/domino/ @guanhuawang @hwchen2017
+/deepspeed/runtime/fp16/ @tjruwase @tohtana
+/deepspeed/runtime/fp16/onebit/ @tjruwase
+/deepspeed/runtime/pipe/ @loadams @tohtana
+/deepspeed/runtime/swap_tensor/ @tjruwase @jomayeri
+/deepspeed/runtime/zero/ @tjruwase @tohtana
From f9e158a0f5cfa08b475cc1f086accffd8a77b92f Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Wed, 18 Dec 2024 09:53:17 -0800
Subject: [PATCH 07/13] Update version.txt after 0.16.2 release (#6893)

**Auto-generated PR to update version.txt after a DeepSpeed release**
Released version - 0.16.2
Author - @loadams

Co-authored-by: loadams
---
 version.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/version.txt b/version.txt
index 201a22c8fa5c..7eb3095a3295 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.16.2
+0.16.3

From 4fd79205c6b85b47e00810143c69d342ce874ede Mon Sep 17 00:00:00 2001
From: Nir Sonnenschein
Date: Thu, 19 Dec 2024 19:26:50 +0200
Subject: [PATCH 08/13] Allow to compile collective for PT>2.3 (#6899)

Allow collectives to be compiled for PyTorch > 2.3. Commit re-uploaded due to
a GitHub CI issue; originally uploaded by @nelyahu.
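The pattern can be shown in isolation (an illustrative sketch using a plain
version check and `torch.compiler.disable`; the patch itself uses DeepSpeed's
`compiler` wrapper and `required_torch_version`, and `all_reduce_wrapper` below
is a hypothetical example function):

```python
import torch

TORCH_GE_2_3 = tuple(int(x) for x in torch.__version__.split(".")[:2]) >= (2, 3)

def disable_compiler_collective(func):
    # On torch >= 2.3 collectives are traceable, so leave the function alone;
    # on older torch, exclude it from torch.compile graphs entirely.
    if TORCH_GE_2_3:
        return func
    return torch.compiler.disable(func)

@disable_compiler_collective
def all_reduce_wrapper(tensor):
    torch.distributed.all_reduce(tensor)
    return tensor
```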
---
 deepspeed/comm/torch.py | 50 +++++++++++++++++++++++-------------------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 988b74232bb9..5461ae18d1f0 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -20,6 +20,12 @@
 DS_COMM_REDUCE_OFF = False
 
 
+def disable_compiler_collective(func):
+    if required_torch_version(min_version=2.3):
+        return func
+    return compiler.disable(func)
+
+
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
     if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
             self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
         return None
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
             return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
-    @compiler.disable
+    @disable_compiler_collective
    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
         else:
             return torch.ops.deepspeed.inference_all_reduce_(tensor)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -169,7 +175,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -177,7 +183,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                                     group=group,
                                                     async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -208,7 +214,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
@@ -216,7 +222,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
                                             group=group,
                                             async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
                     "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
         """"""
         assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
         else:
             reqs[-1].wait()
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
                 "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all_single(self,
                           output,
                           input,
@@ -287,27 +293,27 @@ def all_to_all_single(self,
                                                   group=group,
                                                   async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
@@ -315,7 +321,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
                                         group=group,
                                         async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
@@ -323,13 +329,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
                                          group=group,
                                          async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
 
-    @compiler.disable
+    @disable_compiler_collective
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD

From 00ea0c46c2296db158d10497602f9832c4445d84 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Fri, 20 Dec 2024 02:54:45 +0200
Subject: [PATCH 09/13] Zero2: avoid graph breaks in torch.compile by using
 param_idx (#6803)

Inside reduce_independent_p_g_buckets_and_remove_grads and reduce_ipg_grads,
both of which are executed during the BWD hook in ZeRO stage 2, the model
param used to be stored inside params_in_ipg_bucket. torch.compile has a
hard time tracing parameters. By using the param's static index inside the
group instead, the same logic can be maintained with less complexity.
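A minimal sketch of the idea with toy stand-ins (not the optimizer code
itself):

```python
import torch

# Toy stand-ins for the optimizer state.
bit16_groups = [[torch.nn.Parameter(torch.randn(4)) for _ in range(3)]]
params_in_ipg_bucket = []

# Record each param's static index once, so the bucket can carry plain
# integers instead of parameter objects that torch.compile traces poorly.
for idx, p in enumerate(bit16_groups[0]):
    p.param_idx_in_group = idx

param = bit16_groups[0][1]
params_in_ipg_bucket.append((0, param.param_idx_in_group, id(param)))

# Later (e.g. during gradient reduction) the object is recovered by indexing:
group_idx, param_idx_in_group, _ = params_in_ipg_bucket[0]
assert bit16_groups[group_idx][param_idx_in_group] is param
```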
---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams
---
 deepspeed/runtime/zero/stage_1_and_2.py | 9 ++++++---
 tests/unit/moe/test_moe.py              | 3 ++-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 7ac89a233808..ecb2a527f870 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
             for param in param_group['params']:
                 if param.requires_grad:
                     param.grad_accum = None
+                    param.param_idx_in_group = len(trainable_parameters)
                     trainable_parameters.append(param)
             self.bit16_groups.append(trainable_parameters)
 
@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
 
         self.grads_in_ipg_bucket.append(grad_reduc)
-        self.params_in_ipg_bucket.append((i, param, param_id))
+        self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))
 
         #make sure the average tensor function knows how to average the gradients
         if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
-            for i, param, param_id in self.params_in_ipg_bucket:
+            for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[i][param_idx_in_group]
 
                 process_group = self.dp_process_group
 
@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()
 
         with get_accelerator().stream(stream):
-            for _, param, param_id in self.params_in_ipg_bucket:
+            for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[group_idx][param_idx_in_group]
 
                 assert self.params_already_reduced[param_id] == False, \
                     f"The parameter {param_id} has already been reduced. \
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 9ee546437f6c..c67a907c6785 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
         process_group = optimizer.dp_process_group
         curr_size = 0
         pg_offsets = []
-        for i, param, param_id in optimizer.params_in_ipg_bucket:
+        for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+            param = optimizer.bit16_groups[i][param_idx]
             process_group = optimizer.dp_process_group
             if optimizer.ipg_bucket_has_moe_params:
                 process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(

From eea5304807c6a04d0f2c55cb935ec295235d9b54 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Fri, 20 Dec 2024 07:13:46 +0200
Subject: [PATCH 10/13] hpu_accelerator: use torch.use_deterministic_algorithms
 (#6897)

Use the formal API instead of hpu.setDeterministic.
---
 accelerator/hpu_accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index 723a66e4c6fb..b46351f8ca43 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -21,8 +21,8 @@ def __init__(self):
         self.apply_hpu_workarounds()
         try:
             import habana_frameworks.torch.hpu as hpu
-            hpu.setDeterministic(True)
             self.hpu = hpu
+            torch.use_deterministic_algorithms(True)
         except ImportError as e:
             raise ValueError(
                 f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")

From 85cc5f9bb3f0175a2d13ea1ed65bf7d202b7f0d9 Mon Sep 17 00:00:00 2001
From: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Date: Thu, 26 Dec 2024 09:12:04 -0800
Subject: [PATCH 11/13] Fix error caused by all_reduce call in domino (#6880)

Fixes #6851. Initialize the communication backend to fix an error caused by
the all_reduce call in the Domino transformer layer. Correctness verified in
a local test.
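The guard can be exercised standalone; the single-process Gloo setup below
exists only to make the sketch runnable and is not how Domino is actually
launched (illustrative assumption):

```python
import os
import torch
import deepspeed.comm as dist

# Minimal single-process environment so init_distributed() can succeed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# The guard added by this patch: initialize the backend before collectives.
if not dist.is_initialized():
    dist.init_distributed(dist_backend="gloo")

t = torch.ones(4)
dist.all_reduce(t)  # would previously fail if the backend was never initialized
print(t)
```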
---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/domino/transformer.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 8eb95e49c29d..88c5494c8147 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -6,8 +6,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-import deepspeed
-from deepspeed import comm as dist
+import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
 
 
@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
             return grad_output
 
         # Async All-reduce.
-        handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
         ctx.handle_dic[ctx.h_id] = handle
         return None, grad_output, None, None
 
@@ -249,6 +248,10 @@ def __init__(self,
                  output_bias=None):
         super(DominoTransformerLayer, self).__init__()
 
+        if not dist.is_initialized():
+            dist.init_distributed()
+        assert dist.is_initialized(), "deepspeed.comm is not initialized!"
+
         self.llama_model = config.llama_model
         self.layer_number = layer_number
         self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
                 layernorm_output0,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle0 = deepspeed.comm.all_reduce(attention_output0,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         attention_output1, attention_bias1 = \
             self.self_attention(
                 layernorm_output1,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle1 = deepspeed.comm.all_reduce(attention_output1,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
         handle0.wait()
 
         # Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
             output0 = output0 + bias_c
         output0 = self.mlp_activation_func(output0)
         output0 = torch.matmul(output0, self.weight_r.t())
-        handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         handle1.wait()
 
@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if bias_c is not None:
             output1 = output1 + bias_c
         output1 = torch.matmul(output1, self.weight_r.t())
-        deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+        dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
         handle2.wait()

From cc03c76d57f41752d8cfb84c2e45b8e0da8083da Mon Sep 17 00:00:00 2001
From: Raza Sikander
Date: Fri, 27 Dec 2024 01:37:28 +0530
Subject: [PATCH 12/13] Update Gaudi2 jobs to latest 1.19 build (#6905)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 .github/workflows/hpu-gaudi2-nightly.yml | 2 +-
 .github/workflows/hpu-gaudi2.yml         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml
index 5c5caff1ebb0..c0576360cd61 100644
--- a/.github/workflows/hpu-gaudi2-nightly.yml
+++ b/.github/workflows/hpu-gaudi2-nightly.yml
@@ -21,7 +21,7 @@ jobs:
     # The type of runner that the job will run on
     runs-on: [self-hosted, intel, gaudi2]
     container:
-      image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+      image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
       ports:
         - 80
       options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml
index a06f871b7c56..b8b6f3cb5502 100644
--- a/.github/workflows/hpu-gaudi2.yml
+++ b/.github/workflows/hpu-gaudi2.yml
@@ -39,7 +39,7 @@ jobs:
     # The type of runner that the job will run on
     runs-on: [self-hosted, intel, gaudi2]
     container:
-      image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+      image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
       ports:
         - 80
       options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice

From 3573858e7ce2c723b8c43231c6c6b0cf97dca2fc Mon Sep 17 00:00:00 2001
From: Nir Sonnenschein
Date: Mon, 30 Dec 2024 20:53:41 +0200
Subject: [PATCH 13/13] Change compile for pipeline module torch.compile
 (#6478)

We have encountered an issue with torch.compile and the pipeline module:
modifying a member of the module (micro_offset) during the forward function
causes torch.compile to restart its analysis and treat the module as dynamic.
To bypass this issue without significantly changing the way the pipeline
module works, we propose compiling only the layers in the pipeline module
instead of the pipeline module's forward function. This avoids the issue and
should still give most of the benefit of torch.compile on the pipeline
module.
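The approach can be demonstrated on a toy stage (an illustrative sketch, not
PipelineModule itself; assumes a torch recent enough to provide
nn.Module.compile):

```python
import torch
import torch.nn as nn

# A toy pipeline stage: a mix of nn.Modules and plain callables,
# mirroring PipelineModule.forward_funcs.
forward_funcs = [nn.Linear(8, 8), torch.relu, nn.Linear(8, 4)]

# Compile each layer individually rather than the whole forward pass, so
# mutating stage-level state between layers cannot force a dynamic re-analysis.
for idx, layer in enumerate(forward_funcs):
    if isinstance(layer, nn.Module):
        layer.compile()  # compiles the module in place
    else:
        forward_funcs[idx] = torch.compile(layer)

x = torch.randn(2, 8)
for layer in forward_funcs:
    x = layer(x)
print(x.shape)  # torch.Size([2, 4])
```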
---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/pipe/module.py    | 8 ++++++++
 tests/unit/pipe/test_pipe_module.py | 8 ++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 31fec30be788..9fbd91f750a9 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -662,3 +662,11 @@ def get_additional_losses(self):
         Return a dictionary of {"loss name": loss_value} or None if no additional losses.
         """
         return None
+
+    def compile(self, *args, **kwargs):
+        for idx, layer in enumerate(self.forward_funcs):
+            if isinstance(layer, nn.Module):
+                layer.compile(*args, **kwargs)
+            else:
+                new_layer = torch.compile(layer, *args, **kwargs)
+                self.forward_funcs[idx] = new_layer
diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py
index 05c6a82ef55a..2a8a4b9b7d82 100644
--- a/tests/unit/pipe/test_pipe_module.py
+++ b/tests/unit/pipe/test_pipe_module.py
@@ -60,9 +60,12 @@ def batch_input():
 
 class TestPipeModuleSequential(DistributedTest):
     world_size = 2
+    # needs to be set for torch.compile: running torch.compile with daemonic process causes an error
+    non_daemonic_procs = True
 
     @pytest.mark.parametrize("activation_checkpoints", [False, True])
-    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
         base_model = copy.deepcopy(sequential_model)
         base_input = batch_input.clone().detach()
         base_output = base_model(base_input)
@@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi
 
         pipe_model = copy.deepcopy(sequential_model)
         pipe_model = PipelineModule(layers=pipe_model, num_stages=2)
-
+        if (use_compile):
+            pipe_model.compile()
         # Ensure all parameters are accounted for.
 
         my_params = sum(p.numel() for p in pipe_model.parameters())
         total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())