From 9922270f47be0c349bd9ccce033d5ccdd22bd764 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Mon, 5 Feb 2024 12:52:51 -0800 Subject: [PATCH 1/8] Further refactor deepspeed.moe.utils + deepspeed.moe.layer type hints (#5060) When unpacking a `dict`, keys that appear after the unpacking can overwrite the keys of the unpacked `dict`, meaning we can avoid the pattern of skipping certain keys; we also use `defaultdict` to avoid the boilerplate of assigning the elements of `group_moe`. More type hints and small stylistic changes to `deepspeed.moe.layer` --------- Co-authored-by: Michael Wyatt Co-authored-by: Michael Wyatt --- deepspeed/moe/layer.py | 55 ++++++++++++++++++++++-------------------- deepspeed/moe/utils.py | 29 +++++++++------------- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py index 7dd0c6bcb67d..46f7924ac038 100644 --- a/deepspeed/moe/layer.py +++ b/deepspeed/moe/layer.py @@ -3,22 +3,23 @@ # DeepSpeed Team -import torch +from typing import Optional, Tuple -from deepspeed.utils import log_dist +import torch +from torch import nn +from torch.nn import functional as F -from deepspeed.utils import groups -from .sharded_moe import MOELayer, TopKGate +from deepspeed.utils import groups, log_dist from .experts import Experts -import typing +from .sharded_moe import MOELayer, TopKGate -class MoE(torch.nn.Module): +class MoE(nn.Module): """Initialize an MoE layer. Arguments: hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension. - expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). + expert (nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). num_experts (int, optional): default=1, the total number of experts per layer. ep_size (int, optional): default=1, number of ranks in the expert parallel world or group. k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
@@ -34,20 +35,20 @@ class MoE(torch.nn.Module): """ def __init__(self, - hidden_size, - expert, - num_experts=1, - ep_size=1, - k=1, - capacity_factor=1., - eval_capacity_factor=1., - min_capacity=4, - use_residual=False, - noisy_gate_policy: typing.Optional[str] = None, + hidden_size: int, + expert: nn.Module, + num_experts: int = 1, + ep_size: int = 1, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 4, + use_residual: bool = False, + noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, - use_rts=True, + use_rts: bool = True, use_tutel: bool = False, - enable_expert_tensor_parallelism: bool = False): + enable_expert_tensor_parallelism: bool = False) -> None: super(MoE, self).__init__() @@ -77,12 +78,12 @@ def __init__(self, if self.use_residual: self.mlp = expert # coefficient is used for weighted sum of the output of expert and mlp - self.coefficient = torch.nn.Linear(hidden_size, 2) + self.coefficient = nn.Linear(hidden_size, 2) - def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False): + def set_deepspeed_parallelism(self, use_data_before_expert_parallel_: bool = False) -> None: self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_) - def _create_process_groups(self, use_data_before_expert_parallel_=False): + def _create_process_groups(self, use_data_before_expert_parallel_: bool = False) -> None: # Create process group for a layer if needed if self.expert_group_name not in groups._get_expert_parallel_group_dict(): print(f"No existing process group found, creating a new group named: {self.expert_group_name}") @@ -98,7 +99,9 @@ def _create_process_groups(self, use_data_before_expert_parallel_=False): # Set the group handle for the MOELayer (deepspeed_moe) object self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name)) - def forward(self, hidden_states, used_token=None): + def forward(self, + hidden_states: torch.Tensor, + used_token: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ MoE forward Arguments: @@ -112,15 +115,15 @@ def forward(self, hidden_states, used_token=None): * l_aux (Tensor): gate loss value - * exp_counts (int): expert count + * exp_counts (Tensor): expert count """ output = self.deepspeed_moe(hidden_states, used_token) if self.use_residual: # Residual MoE output_mlp = self.mlp(hidden_states) - if type(output_mlp) is tuple: + if isinstance(output_mlp, tuple): output_mlp = output_mlp[0] # Ignore the bias term for now coef = self.coefficient(hidden_states) - coef = torch.nn.functional.softmax(coef, dim=-1) + coef = F.softmax(coef, dim=-1) output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py index db23aa46712e..8e1faffc3541 100644 --- a/deepspeed/moe/utils.py +++ b/deepspeed/moe/utils.py @@ -3,6 +3,7 @@ # DeepSpeed Team +from collections import defaultdict from typing import Any, Dict, List, Set, Tuple, Union, cast import torch @@ -68,10 +69,9 @@ def split_params_grads_into_shared_and_expert_params( return shared_grads, expert_grads -def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any], - ...], - List[Dict[str, Any]]], - max_group_size: int = 178956971) -> List[Dict[str, Any]]: +def split_params_into_different_moe_groups_for_optimizer( + param_groups: Union[Dict[str, Any], 
Tuple[Dict[str, Any], ...], List[Dict[str, Any]]], + max_group_size: Union[int, float] = 178956971) -> List[Dict[str, Any]]: """Split parameters into different MoE groups for optimizer Args: @@ -97,18 +97,15 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dic data_parallel_group_names.add(param.group_name) # Create the param MoE groups, leave param assign to next step - group_moe: Dict[str, Dict[str, Dict[str, Any]]] = {} + group_moe: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) for param_group in param_groups: - group_moe[param_group['name']] = {} for key in data_parallel_group_names: - group_moe[param_group['name']][key] = {} - group_moe[param_group['name']][key]['name'] = key - group_moe[param_group['name']][key]['moe'] = True - - for ori_key in param_group.keys(): - if ori_key != 'name': - group_moe[param_group['name']][key][ori_key] = ([] - if ori_key == 'params' else param_group[ori_key]) + group_moe[param_group['name']][key] = { + **param_group, + 'name': key, + 'moe': True, + 'params': [], + } # Assign param for param_group in param_groups: @@ -142,9 +139,7 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dic all_groups.append(cur_group) for group in all_groups: - new_dict = dict(param_group) - new_dict['params'] = group - param_groups.append(new_dict) + param_groups.append({**param_group, 'params': group}) else: for moe_group in group_moe.values(): for param_group in moe_group.values(): From f02d7bdadf2cee23719d148339a88b24d668a149 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:32:56 -0800 Subject: [PATCH 2/8] Fix verification for ZeRO3 leaf module (#5074) This PR improves verification for the ZeRO3 leaf module. The leaf module requires input tensors with `requires_grad=True` to launch reduce_scatter from backward hooks. Currently we throw an error if none of the input tensors to the leaf module *requires grad*. This prevents us from using leaf modules in some scenarios, including inference or activation checkpointing, as reported in #5008. This PR addresses the issue by checking output tensors as well as input tensors: the hook no longer throws an error if no output tensor requires grad.
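As a rough sketch of the relaxed condition (illustrative only; the real hooks live in `deepspeed/runtime/zero/stage3.py` and traverse tensors with `apply_to_tensors_only`, and the helper name below is made up):

```python
import torch

def leaf_module_check(input_tensors, output_tensors):
    # Count how many inputs/outputs actually participate in autograd.
    inputs_requiring_grad = sum(1 for t in input_tensors if torch.is_tensor(t) and t.requires_grad)
    outputs_requiring_grad = sum(1 for t in output_tensors if torch.is_tensor(t) and t.requires_grad)

    # Old behavior: raise whenever no input requires grad.
    # New behavior: raise only when the backward hook would actually be needed,
    # i.e. some output requires grad but no input does.
    if inputs_requiring_grad == 0 and outputs_requiring_grad > 0:
        raise RuntimeError("Leaf module has outputs that require grad but no inputs that do; "
                           "the gradient reduction hook would never be called.")
```

Under inference or activation checkpointing, neither inputs nor outputs require grad, so this check now passes instead of raising.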
--- deepspeed/runtime/zero/stage3.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ad84eb5bb390..b76b781346e7 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1149,7 +1149,7 @@ def reduce_partition_and_remove_grads(*notneeded): # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done for leaf_module, leaf_parameters in self.leaf_parameters.items(): - def wrapper(params): + def wrapper_pre_hook(params): def forward_pre_hook(module, input): """Pre-forward hook to set backward hook on input tensors to the leaf module""" @@ -1173,16 +1173,32 @@ def set_module_bwd_hook(tensor): output = apply_to_tensors_only(set_module_bwd_hook, input) - if module._leaf_module_inputs_remaining == 0: - raise RuntimeError( - "A module cannot be set as a leaf module when it does not have any input tensors that require gradients" - ) - return output return forward_pre_hook - self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper(leaf_parameters))) + def wrapper_post_hook(): + + def forward_post_hook(module, input, output): + """Pre-forward hook to set backward hook on input tensors to the leaf module""" + module._leaf_output_required_grad_num = 0 + + def increment_rg_count_bwd_hook(tensor): + if tensor.requires_grad: + module._leaf_output_required_grad_num += 1 + return tensor + + apply_to_tensors_only(increment_rg_count_bwd_hook, output) + + if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0: + raise RuntimeError( + "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case." + ) + + return forward_post_hook + + self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters))) + self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook())) print_rank_0(f'[End] Create gradient reduction hooks') From 5a721de32c21082b81e3f6ae4d7d927fcfa7ce39 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:33:29 -0800 Subject: [PATCH 3/8] Stop tracking backward chain of broadcast in initialization (#5075) DeepSpeed engine generates the following warning upon initialization. This warning is triggered by a broadcast that synchronizes model parameters across ranks. Although this is harmless in terms of both accuracy and, likely, performance, it may confuse users and potentially cause compatibility issues with future versions of PyTorch. This PR runs the broadcast within a `torch.no_grad` context to prevent tracking of the backward computation chain. ``` /home/aiscuser/.conda/envs/wbcast/lib/python3.9/site-packages/torch/autograd/__init__.py:266: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). 
If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1704987277512/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass ``` Co-authored-by: Michael Wyatt --- deepspeed/runtime/engine.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ac0830d6f197..58dac4e44b77 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1056,16 +1056,17 @@ def is_replicated(p): return False return True - for p in self.module.parameters(): - # Broadcast the model for different parameters - if is_moe_param(p): - if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, - groups._get_expert_broadcast_src_rank(p.group_name), - group=self.expert_data_parallel_group[p.group_name]) - else: - if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) + with torch.no_grad(): + for p in self.module.parameters(): + # Broadcast the model for different parameters + if is_moe_param(p): + if torch.is_tensor(p) and is_replicated(p): + dist.broadcast(p, + groups._get_expert_broadcast_src_rank(p.group_name), + group=self.expert_data_parallel_group[p.group_name]) + else: + if torch.is_tensor(p) and is_replicated(p): + dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: From 889620b0a44be9622ebcdd65be5a809a2f3c310b Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 5 Feb 2024 16:02:55 -0800 Subject: [PATCH 4/8] Update nv-torch-latest-version --- .github/workflows/nv-torch-latest-cpu.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml index 7923997113ed..3f3e2e6f0c6f 100644 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ b/.github/workflows/nv-torch-latest-cpu.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==2.1.2+cpu torchvision==0.16.2+cpu --extra-index-url https://download.pytorch.org/whl/cpu python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -46,5 +46,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="1.12" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="1.12" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.1" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.1" From 55eb78ee1fbfe10950121cabc678edb32643a0b1 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 5 Feb 2024 16:03:21 -0800 Subject: [PATCH 5/8] Revert "Update nv-torch-latest-version" This reverts commit 889620b0a44be9622ebcdd65be5a809a2f3c310b. 
--- .github/workflows/nv-torch-latest-cpu.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml index 3f3e2e6f0c6f..7923997113ed 100644 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ b/.github/workflows/nv-torch-latest-cpu.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install torch==2.1.2+cpu torchvision==0.16.2+cpu --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -46,5 +46,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.1" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.1" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="1.12" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="1.12" From e469e7d98c26602ccebc1f3757ace8132cbb5247 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:14:20 -0800 Subject: [PATCH 6/8] Update torch version for nv-torch-latest-cpu (#5086) Given the name of this test, we should be running a newer version of torch than we were, this updates to a newer version. Total test coverage remains the same: ``` torch 1.12 ========== 151 passed, 891 skipped, 20 warnings in 126.58s (0:02:06) =========== ========= 4 passed, 61 skipped, 5342 deselected, 2 warnings in 48.66s ========== torch 2.2 ========== 151 passed, 880 skipped, 20 warnings in 157.64s (0:02:37) =========== ========= 4 passed, 62 skipped, 5330 deselected, 2 warnings in 55.73s ========== ``` --- .github/workflows/nv-torch-latest-cpu.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml index 7923997113ed..4075a46ec913 100644 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ b/.github/workflows/nv-torch-latest-cpu.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -46,5 +46,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="1.12" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="1.12" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2" From e212845e396e9fb1decce09dfa2800e3488a1704 Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:37:44 -0800 Subject: [PATCH 7/8] Add backwards compatibility w/ older versions of 
diffusers (<0.25.0) (#5083) This PR adds backwards compatibility for older versions of `diffusers` (`<0.25.0`) by updating the `vae` container import logic to account for changes between the various versions. --- deepspeed/module_inject/containers/vae.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/deepspeed/module_inject/containers/vae.py b/deepspeed/module_inject/containers/vae.py index 297a796977f1..d26d0ef77ca9 100644 --- a/deepspeed/module_inject/containers/vae.py +++ b/deepspeed/module_inject/containers/vae.py @@ -13,11 +13,17 @@ def __init__(self): super().__init__() try: import diffusers - if hasattr(diffusers.models.autoencoders.vae, "AutoencoderKL"): - self._orig_layer_class = diffusers.models.autoencoders.vae.AutoencoderKL - else: - # Diffusers >= 0.12.0 changes location of AutoencoderKL + if hasattr(diffusers.models, "autoencoders"): + # Diffusers >= 0.25.0 + # Changes location to 'autoencoders' directory + self._orig_layer_class = diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL + elif hasattr(diffusers.models.vae, "AutoencoderKL"): + # Diffusers < 0.12.0 + self._orig_layer_class = diffusers.models.vae.AutoencoderKL + else: + # Diffusers >= 0.12.0 & < 0.25.0 + # Changes location of AutoencoderKL + self._orig_layer_class = diffusers.models.autoencoder_kl.AutoencoderKL except ImportError: self._orig_layer_class = None From c3cfe96bb3374ae6d8ff200e7487a1562de43e11 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 5 Feb 2024 21:11:49 -0800 Subject: [PATCH 8/8] Enable torch.compile with ZeRO (Experimental) (#4878) This PR enables `torch.compile` with ZeRO stages 1/2/3. You need to add a `compile` section to your DeepSpeed config. The fields in the section are passed to `torch.compile`. ```json "compile": { "enabled": true, "backend": "inductor" } ``` To enable a custom backend, you can pass the fully qualified name of the backend function. For example, if you have a backend function `my_backend` in `my_backend.py` in the current directory, you can enable it by `"backend": "my_backend.my_backend"`. You can find an example in [a unit test](https://github.com/microsoft/DeepSpeed/blob/eb9d4e06e9596f391aea305a6a5c6ec70cc28b58/tests/unit/runtime/compile/test_config.py#L116). So far we have validated the results with Megatron-DeepSpeed. See the [example](https://github.com/microsoft/Megatron-DeepSpeed/tree/tohtana/enable_compile/examples_deepspeed/compile) for details. NOTICE: This PR is a draft. We will need to validate the coverage and accuracy with many more examples.
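For illustration, a custom backend referenced as `"backend": "my_backend.my_backend"` could look roughly like the sketch below (the file and function names mirror the hypothetical example above; the body just returns the captured graph unchanged):

```python
# my_backend.py - hypothetical module matching the dotted name above
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # A torch.compile backend receives the captured FX graph and example inputs
    # and returns a callable; returning gm.forward runs the graph unmodified.
    print(f"compiling a graph with {len(list(gm.graph.nodes))} nodes")
    return gm.forward
```

DeepSpeed resolves the dotted name by importing `my_backend` from `my_backend.py`, which is what `get_backend_fn` in `deepspeed/runtime/compiler.py` does.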
--------- Co-authored-by: Olatunji Ruwase Co-authored-by: Michael Wyatt --- deepspeed/__init__.py | 1 + deepspeed/comm/torch.py | 24 +++ deepspeed/runtime/compiler.py | 166 ++++++++++++++++++ deepspeed/runtime/config.py | 3 + deepspeed/runtime/engine.py | 39 ++-- deepspeed/runtime/zero/mics.py | 3 +- deepspeed/runtime/zero/parameter_offload.py | 162 ++++++++--------- .../runtime/zero/partition_parameters.py | 19 +- .../zero/partitioned_param_coordinator.py | 6 + deepspeed/runtime/zero/utils.py | 23 ++- tests/unit/common.py | 83 +++++++-- .../runtime/compile/test_compile_wrapper.py | 79 +++++++++ .../unit/runtime/compile/test_compile_zero.py | 63 +++++++ .../unit/runtime/compile/test_load_config.py | 114 ++++++++++++ tests/unit/runtime/compile/util.py | 124 +++++++++++++ 15 files changed, 784 insertions(+), 125 deletions(-) create mode 100644 deepspeed/runtime/compiler.py create mode 100644 tests/unit/runtime/compile/test_compile_wrapper.py create mode 100644 tests/unit/runtime/compile/test_compile_zero.py create mode 100644 tests/unit/runtime/compile/test_load_config.py create mode 100644 tests/unit/runtime/compile/util.py diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 61ae016eef4b..a6feae06974b 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -39,6 +39,7 @@ from .runtime import zero from .runtime import DeepSpeedOptimizer, ZeROOptimizer +from .runtime.compiler import is_compile_supported from .pipe import PipelineModule diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 531f54d2628f..d96c63d84e60 100644 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -8,6 +8,7 @@ from .utils import * from .backend import * from .comm import * +from ..runtime import compiler import os DS_COMM_ALL_GATHER_OFF = False @@ -120,6 +121,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name=' self.init_process_group(backend, timeout, init_method, rank, world_size) @classmethod + @compiler.disable def get_all_gather_function(self): if hasattr(torch.distributed, "all_gather_into_tensor"): return torch.distributed.all_gather_into_tensor @@ -128,6 +130,7 @@ def get_all_gather_function(self): return None @classmethod + @compiler.disable def get_reduce_scatter_function(self): if hasattr(torch.distributed, "reduce_scatter_tensor"): return torch.distributed.reduce_scatter_tensor @@ -150,14 +153,17 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size): world_size=world_size) self.using_mpi = torch.distributed.get_backend() == 'mpi' + @compiler.disable def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) + @compiler.disable def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) + @compiler.disable def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): """ proxy func to torch.distributed.all_reduce_coalesced, which is included in PyTorch 1.13 and above @@ -168,6 +174,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group op = self._reduce_op(op) return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op) + @compiler.disable def reduce(self, tensor, dst, op=ReduceOp.SUM, 
group=None, async_op=False): if DS_COMM_REDUCE_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -175,6 +182,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): return Noop() return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op) + @compiler.disable def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False): if DS_COMM_REDUCE_SCATTER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -187,6 +195,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_ group=group, async_op=async_op) + @compiler.disable def broadcast(self, tensor, src, group=None, async_op=False): if DS_COMM_BROADCAST_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -195,6 +204,7 @@ def broadcast(self, tensor, src, group=None, async_op=False): else: return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) + @compiler.disable def all_gather(self, tensor_list, tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -203,6 +213,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False): else: return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op) + @compiler.disable def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False): if self.has_all_gather_into_tensor(): return self.all_gather_function(output_tensor=output_tensor, @@ -210,6 +221,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_ group=group, async_op=async_op) + @compiler.disable def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -227,6 +239,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals "please consider upgrading your pytorch installation.") pass + @compiler.disable def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False): """""" assert len(output_tensors) == len(input_tensors), "" @@ -250,6 +263,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_ else: reqs[-1].wait() + @compiler.disable def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False): if self.has_reduce_scatter_tensor(): return self.reduce_scatter_function(output_tensor, @@ -263,6 +277,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr "please consider upgrading your pytorch installation.") pass + @compiler.disable def all_to_all_single(self, output, input, @@ -277,21 +292,27 @@ def all_to_all_single(self, group=group, async_op=async_op) + @compiler.disable def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False): return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op) + @compiler.disable def send(self, tensor, dst, group=None, tag=0): return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag) + @compiler.disable def recv(self, tensor, src=None, group=None, tag=0): return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag) + @compiler.disable def isend(self, tensor, dst, group=None, tag=0): return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag) + @compiler.disable def irecv(self, tensor, src=None, group=None, tag=0): return 
torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag) + @compiler.disable def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): return torch.distributed.gather(tensor=tensor, gather_list=gather_list, @@ -299,6 +320,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): group=group, async_op=async_op) + @compiler.disable def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): return torch.distributed.scatter(tensor=tensor, scatter_list=scatter_list, @@ -306,11 +328,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): group=group, async_op=async_op) + @compiler.disable def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None): if group is None: group = torch.distributed.GroupMember.WORLD return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids) + @compiler.disable def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False): if group is None: group = torch.distributed.GroupMember.WORLD diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py new file mode 100644 index 000000000000..603f563fca60 --- /dev/null +++ b/deepspeed/runtime/compiler.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Union, Callable, Dict, Any +import importlib +import torch +from ..pydantic_v1 import validator +from .config_utils import DeepSpeedConfigModel + +COMPILE_CONFIG = "compile" + + +def is_compile_supported(): + return hasattr(torch, "compile") + + +def disable(func): + if is_compile_supported(): + return torch.compiler.disable(func) + return func + + +def get_compile_config(param_dict): + if COMPILE_CONFIG in param_dict: + compile_config_dict = param_dict[COMPILE_CONFIG] + else: + compile_config_dict = {} + return CompileConfig(**compile_config_dict) + + +def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]: + if isinstance(backend, Callable): + return backend + + elif isinstance(backend, str): + if backend in torch._dynamo.list_backends(): + return backend + + # Get module name from backend name + module_name = '.'.join(backend.split('.')[:-1]) + fn_name = backend.split('.')[-1] + + try: + module = importlib.import_module(module_name) + backend_fn = getattr(module, fn_name) + except ImportError: + raise ValueError( + f"The backend {backend} is not in the list of available backends and could not be imported.") + return backend_fn + + raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}") + + +class CompileConfig(DeepSpeedConfigModel): + """ + [EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings. + Please be aware that these features and API designs are experimental and subject to change. + """ + + enabled: bool = False + """ + Enable torch.compile when True. + """ + + backend: str = "inductor" + """ + Passed to `backend` argument of torch.compile. + If the given value is not in torch._dynamo.list_backends(), + DeepSpeed attempts to import and instantiate the module with the given name. + """ + + kwargs: Dict[str, Any] = {} + """ + Passed to `kwargs` argument of torch.compile. 
+ """ + + @validator("enabled") + def validate_enabled(cls, field_value, values): + if field_value and not is_compile_supported(): + raise ValueError("torch.compile is not supported on this version of PyTorch.") + return field_value + + +class CompiledModuleWrapper(torch.nn.Module): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + super().__init__() + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + modules = self.__dict__.get('_modules') + modules['wrapped'] = module + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def __getattr__(self, name): + return getattr(self.__dict__['wrapped'], name) + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. + + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) 
+ + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 80754df50c20..20fbf475ca90 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import get_monitor_config from ..inference.config import WeightQuantConfig +from .compiler import get_compile_config from deepspeed import comm as dist from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -899,6 +900,8 @@ def _initialize_params(self, param_dict): self.weight_quantization_config = WeightQuantConfig( **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None + self.compile_config = get_compile_config(param_dict) + def _batch_assertion(self): train_batch = self.train_batch_size diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 58dac4e44b77..f0602813f3ab 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -90,6 +90,7 @@ from .pipe.module import PipelineModule from .utils import get_ma_status +from .compiler import CompiledModuleWrapper from ..ops.adam import FusedAdam from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.layer import MoE @@ -179,21 +180,19 @@ def __init__(self, enable_micro_timers, enable_global_timers): class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" - def __init__( - self, - args, - model, - optimizer=None, - model_parameters=None, - training_data=None, - lr_scheduler=None, - mpu=None, - dist_init_required=None, - collate_fn=None, - config=None, - config_class=None, - dont_change_device=False, - ): + def __init__(self, + args, + model, + optimizer=None, + model_parameters=None, + training_data=None, + lr_scheduler=None, + mpu=None, + dist_init_required=None, + collate_fn=None, + config=None, + config_class=None, + dont_change_device=False): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device self.client_optimizer = optimizer @@ -363,6 +362,9 @@ def __init__( self.flatten = _flatten_dense_tensors self.unflatten = _unflatten_dense_tensors + if self._config.compile_config.enabled: + self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config)) + def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() @@ -467,6 +469,13 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) + elif isinstance(_module, CompiledModuleWrapper): + try: + return getattr(_module, name) + except AttributeError: + raise AttributeError( + f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" + ) else: raise 
AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index 05e59eccfdae..1e5c9396be1d 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -12,8 +12,9 @@ import deepspeed import torch from deepspeed import comm as dist +from deepspeed.runtime.zero.utils import is_zero_param from deepspeed.runtime.zero.mics_utils import (MiCS_CommGroups, create_mics_comm_groups, scale_tensors) -from deepspeed.runtime.zero.parameter_offload import (DeepSpeedZeRoOffload, is_zero_param) +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 from deepspeed.utils import instrument_w_nvtx, log_dist diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 56cc4af19840..e9e79c2647fb 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -8,7 +8,7 @@ from collections import OrderedDict from deepspeed.utils import z3_leaf_module from deepspeed.runtime.utils import see_memory_usage -from deepspeed.runtime.zero.utils import apply_to_tensors_only +from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * @@ -21,20 +21,6 @@ warned = False -def _apply_backward_to_tensors_only(module, functional, backward_function, outputs): - - def apply_to_tensor_fn(tensor): - # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - touched_outputs = functional.apply(module, backward_function, tensor) - - # restore zero param attributes if those get stripped by `backward_function` - if not is_zero_param(touched_outputs) and is_zero_param(tensor): - touched_outputs.ds_param_alias = tensor - return touched_outputs - - return apply_to_tensors_only(apply_to_tensor_fn, outputs) - - #for each tensor in outputs run the forward_function and register backward_function as hook def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): if type(outputs) is tuple: @@ -94,54 +80,6 @@ def _inject_parameters(module, cls): module._parameters = new_param -class PreBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, post_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 
'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.post_backward_function = post_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.post_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - class DeepSpeedZeRoOffload(object): def __init__( @@ -341,6 +279,7 @@ def _pre_forward_module_hook(module, *args): @instrument_w_nvtx def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK FWD_MODULE_STACK.pop() if output is None: @@ -381,20 +320,49 @@ def _post_forward_module_hook(module, input, output): self.post_sub_module_forward_function(module) - def _pre_backward_module_hook(module, inputs, output): + def _bwd_hook_unexpected_inputs_msg(value): + return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \ + "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " \ + "output tensors and therefore may not get triggered properly." - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + def _pre_backward_module_hook(module, inputs, output): - return _apply_backward_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output) + if not hasattr(module, "pre_bwd_fn"): + + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. 
Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, outputs): + # Capture `module` and _run_before_backward_function + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args + + module.pre_bwd_fn = PreBackwardFunctionForModule + + return apply_to_tensors_only(module.pre_bwd_fn.apply, + output, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) #This is an alternate to doing _post_backward_module_hook #it uses tensor.register_hook instead of using torch.autograd.Function @@ -419,12 +387,44 @@ def _run_before_forward_function(input): def _post_backward_module_hook(module, inputs): module.ds_grads_remaining = 0 - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_backward_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs) + if not hasattr(module, "post_bwd_fn"): + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." 
+ #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args + + module.post_bwd_fn = PostBackwardFunctionModule + + return apply_to_tensors_only(module.post_bwd_fn.apply, + inputs, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) # Pre forward hook self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index c55cf5bdfa17..99a9d100082b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -25,7 +25,7 @@ import deepspeed from ..utils import see_memory_usage from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.config_utils import get_config_default from deepspeed.utils import instrument_w_nvtx, logger @@ -109,12 +109,6 @@ def debug_rank0(msg: str) -> None: logger.debug(msg) -def is_zero_param(parameter): - if not torch.is_tensor(parameter): - return False - return hasattr(parameter, 'ds_id') - - def _init_external_params(module): if not hasattr(module, '_external_params'): module._external_params = {} @@ -911,7 +905,16 @@ def __init__( _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) if config_dict_or_path is not None else None if _ds_config is not None: - mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled: + # memory_efficient_linear displays numerous errors when torch.compile is enabled. + # Refer to https://github.com/pytorch/pytorch/issues/119059 for details. + # Further investigation into performance is necessary, even after resolving this issue because + # the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation. 
+ logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.') + mem_efficient_linear = False + else: + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) if not dist.is_initialized(): init_distributed() diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index a3f2319b6a2c..cfeae9e7839a 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -17,6 +17,8 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator +import deepspeed.runtime.compiler as compiler + import logging ENABLE_PROFILER = False @@ -175,6 +177,7 @@ def trace_prologue(self, sub_module: Module) -> None: force=True) self._invalidate_trace() + @compiler.disable def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" if not self.is_record_trace(): @@ -252,6 +255,7 @@ def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None): Fetching, prefetching, and releasing parameters """ + @compiler.disable @instrument_w_nvtx @torch.no_grad() def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: @@ -272,6 +276,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))) fetch_numel = sum( [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + if fetch_numel > 0: event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT self._dump_param_ids(event_name, current_submodule.id, @@ -468,6 +473,7 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: if swap_persisted_params: swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params) + @compiler.disable @instrument_w_nvtx def __release_param(self, param: Parameter) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 78eaaba59ebb..f61715bd4387 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -108,19 +108,22 @@ def isinstance_namedtuple(obj: object) -> bool: return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') +def is_zero_param(parameter): + if not torch.is_tensor(parameter): + return False + return hasattr(parameter, 'ds_id') + + def apply_to_tensors_only(function, value, warning_msg_fn=None): """ Apply `function` to every Tensor in `value`. Args: - module (torch.nn.Module): A torch module - functional (Type[torch.autograd.Function]): The function class to apply. - backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to - `functional.apply`. - outputs (Any): The output of `module`. + functional: The function class to apply. + value (Any): Target object to apply `function` to. Returns: - Any: The output of `module`. + Any: Output of `function`. 
""" if isinstance(value, (tuple, list)): touched_outputs = [] @@ -141,7 +144,13 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): elif isinstance(value, torch.Tensor): # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - return function(value) + touched_output = function(value) + + # restore zero param attributes if those get stripped by `backward_function` + if not is_zero_param(touched_output) and is_zero_param(value): + touched_output.ds_param_alias = value + + return touched_output else: if not is_builtin_type(value): global warned diff --git a/tests/unit/common.py b/tests/unit/common.py index cdeca54b01ee..420db577cf09 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -113,6 +113,7 @@ class DistributedExec(ABC): set_dist_env = True requires_cuda_env = True reuse_dist_env = False + non_daemonic_procs = False _pool_cache = {} exec_timeout = DEEPSPEED_TEST_TIMEOUT @@ -145,16 +146,7 @@ def _get_fixture_kwargs(self, request, func): pass # test methods can have kwargs that are not fixtures return fixture_kwargs - def _launch_procs(self, num_procs): - # Verify we have enough accelerator devices to run this test - if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: - pytest.skip( - f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" - ) - - # Set start method to `forkserver` (or `fork`) - mp.set_start_method('forkserver', force=True) - + def _launch_daemonic_procs(self, num_procs): # Create process pool or use cached one master_port = None if self.reuse_dist_env: @@ -186,8 +178,70 @@ def _launch_procs(self, num_procs): assert len(set(skip_msgs)) == 1, "Multiple different skip messages received" pytest.skip(skip_msgs[0]) - def _dist_run(self, local_rank, num_procs, master_port): - skip_msg = '' + def _launch_non_daemonic_procs(self, num_procs): + assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes" + + master_port = get_master_port() + skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason + processes = [] + for local_rank in range(num_procs): + p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg)) + p.start() + processes.append(p) + + # Now loop and wait for a test to complete. The spin-wait here isn't a big + # deal because the number of processes will be O(#GPUs) << O(#CPUs). + any_done = False + start = time.time() + while (not any_done) and ((time.time() - start) < self.exec_timeout): + for p in processes: + if not p.is_alive(): + any_done = True + break + time.sleep(.1) # So we don't hog CPU + + # If we hit the timeout, then presume a test is hanged + if not any_done: + for p in processes: + p.terminate() + pytest.exit("Test hanged, exiting", returncode=0) + + # Wait for all other processes to complete + for p in processes: + p.join(self.exec_timeout) + + failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0] + for rank, p in failed: + # If it still hasn't terminated, kill it because it hung. 
+ if p.exitcode is None: + p.terminate() + pytest.fail(f'Worker {rank} hung.', pytrace=False) + if p.exitcode < 0: + pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False) + if p.exitcode > 0: + pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False) + + if not skip_msg.empty(): + # This assumed all skip messages are the same, it may be useful to + # add a check here to assert all exit messages are equal + pytest.skip(skip_msg.get()) + + def _launch_procs(self, num_procs): + # Verify we have enough accelerator devices to run this test + if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: + pytest.skip( + f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" + ) + + # Set start method to `forkserver` (or `fork`) + mp.set_start_method('forkserver', force=True) + + if self.non_daemonic_procs: + self._launch_non_daemonic_procs(num_procs) + else: + self._launch_daemonic_procs(num_procs) + + def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""): if not dist.is_initialized(): """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: @@ -218,7 +272,10 @@ def _dist_run(self, local_rank, num_procs, master_port): self.run(**self._fixture_kwargs) except BaseException as e: if isinstance(e, Skipped): - skip_msg = e.msg + if self.non_daemonic_procs: + skip_msg.put(e.msg) + else: + skip_msg = e.msg else: raise e diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py new file mode 100644 index 000000000000..fbf235fb7d62 --- /dev/null +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +import deepspeed +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + + +@pytest.fixture +def base_config(): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + }, + "compile": { + "enabled": True, + "backend": "inductor" + } + } + return config_dict + + +class SmallModelWithCustomMethod(torch.nn.Module): + + def __init__(self, hidden_dim, test_value): + super(SmallModelWithCustomMethod, self).__init__() + self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v = test_value + + def forward(self, x): + return self.fc(x) + + # Custom function that is not part of DeepSpeed engine. 
+    def get_v(self):
+        return self.v
+
+
+class TestCustomMethod(DistributedTest):
+    world_size = 1
+    non_daemonic_procs = True
+
+    def _init_engine(self, config, test_value):
+        hidden_dim = 10
+        model = SmallModelWithCustomMethod(hidden_dim, test_value)
+        engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters())
+        return engine
+
+    def _run_model(self, engine):
+        train_batch_size = 1
+        device = torch.device(get_accelerator().current_device_name())
+        dtype = engine.module.fc.weight.dtype
+        hidden_dim = engine.module.fc.weight.shape[1]
+        x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype)
+        engine(x)
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_custom_function(self, base_config):
+        test_value = 10
+
+        engine = self._init_engine(base_config, test_value)
+        assert engine.module.get_v() == test_value
+        self._run_model(engine)
+
+        # The model is compiled after the first run.
+        # Thus we make sure the custom method is still available after compilation.
+        assert engine.module.get_v() == test_value
diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py
new file mode 100644
index 000000000000..87e3c52b9e3c
--- /dev/null
+++ b/tests/unit/runtime/compile/test_compile_zero.py
@@ -0,0 +1,63 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+
+from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
+
+from unit.runtime.compile.util import compare_loss
+from unit.common import DistributedTest
+from unit.util import bf16_required_version_check
+
+
+class TestZeRO(DistributedTest):
+    world_size = 2
+    non_daemonic_procs = True
+
+    @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
+    @pytest.mark.parametrize('zero_stage', [1, 2, 3])
+    @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
+    def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
+        if dtype == torch.bfloat16 and not bf16_required_version_check():
+            pytest.skip(
+                "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA >= 11.0 and HW support for BFloat16 to run correctly"
+            )
+
+        if offload_device == OffloadDeviceEnum.nvme:
+            if zero_stage != 3:
+                pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}")
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.00015
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+            },
+            "compile": {
+                "enabled": True,
+                "backend": "inductor"
+            }
+        }
+
+        if offload_device == OffloadDeviceEnum.cpu:
+            config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
+        elif offload_device == OffloadDeviceEnum.nvme:
+            config_dict["zero_optimization"]["offload_optimizer"] = {
+                "device": offload_device,
+                "nvme_path": str(tmpdir)
+            }
+        if dtype == torch.float16:
+            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
+        elif dtype == torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        compare_loss(self, config_dict, dtype)
diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py
new file mode 100644
index 000000000000..351e91d2f69b
--- /dev/null
+++ b/tests/unit/runtime/compile/test_load_config.py
@@ -0,0 +1,114 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+
+from unit.simple_model import SimpleModel
+import deepspeed
+from deepspeed.accelerator import get_accelerator
+
+from unit.common import DistributedTest
+
+custom_backend_called = False
+custom_compiler_fn_called = False
+
+if deepspeed.is_compile_supported():
+    # PyTorch v1 does not have torch.fx
+    def custom_backend(gm: torch.fx.GraphModule, example_inputs):
+        global custom_backend_called
+        custom_backend_called = True
+        return gm.forward
+
+    def custom_compiler_fn(module: torch.nn.Module):
+        global custom_compiler_fn_called
+        custom_compiler_fn_called = True
+        return torch.compile(module)
+
+
+@pytest.fixture
+def base_config():
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "compile": {
+            "enabled": True,
+            "backend": "inductor"
+        }
+    }
+    return config_dict
+
+
+class TestConfigLoad(DistributedTest):
+    world_size = 1
+    non_daemonic_procs = True
+
+    def _init_engine(self, config):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters())
+        return engine
+
+    def _run_model(self, engine):
+        train_batch_size = 1
+        device = torch.device(get_accelerator().current_device_name())
+        dtype = engine.module.linears[0].weight.dtype
+        hidden_dim = engine.module.linears[0].weight.shape[1]
+        x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype)
+        y = torch.randn_like(x)
+        engine(x, y)
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_compile(self, base_config):
+        engine = self._init_engine(base_config)
+        self._run_model(engine)
+        assert engine.is_compiled
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_custom_backend(self, base_config):
+        global custom_backend_called
+        custom_backend_called = False
+
+        engine = self._init_engine(base_config)
+        engine.set_backend(f"{__name__}.custom_backend")
+        self._run_model(engine)
+        assert custom_backend_called
+
+    def test_compile_disabled(self, base_config):
+        base_config["compile"]["enabled"] = False
+        engine = self._init_engine(base_config)
+        self._run_model(engine)
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_compile_kwargs(self, base_config):
+        base_config["compile"]["kwargs"] = {"mode": "default"}
+        engine = self._init_engine(base_config)
+        self._run_model(engine)
+        assert "mode" in engine.torch_compile_kwargs
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_set_compile_kwargs(self, base_config):
+        engine = self._init_engine(base_config)
+        engine.set_torch_compile_kwargs({"mode": "default"})
+        self._run_model(engine)
+        assert "mode" in engine.torch_compile_kwargs
+
+    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
+    def test_set_compiler_fn(self, base_config):
+        global custom_compiler_fn_called
+        custom_compiler_fn_called = False
+
+        engine = self._init_engine(base_config)
+        engine.set_compiler_fn(custom_compiler_fn)
+        self._run_model(engine)
+        assert custom_compiler_fn_called
diff --git a/tests/unit/runtime/compile/util.py b/tests/unit/runtime/compile/util.py
new file mode 100644
index 000000000000..86eadf3f6976
--- /dev/null
+++ b/tests/unit/runtime/compile/util.py
@@ -0,0 +1,124 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import random
+import os
+import numpy as np
+from copy import deepcopy
+
+import torch
+
+import deepspeed
+from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.zero import GatheredParameters
+
+from unit.simple_model import SimpleModel
+from typing import Callable, Any
+
+
+class EnableDeterminism:
+
+    def __init__(self, seed: int):
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        self.seed = seed + local_rank
+        self.saved_random_state = None
+        self.saved_np_random_state = None
+        self.saved_cuda_launch_blocking = None
+        self.saved_cublas_workspace_config = None
+        self.saved_deterministic_algorithms = None
+
+    def __enter__(self):
+        self.saved_random_state = random.getstate()
+        self.saved_np_random_state = np.random.get_state()
+        self.saved_acc_rng_state = get_accelerator().get_rng_state()
+        self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "")
+        self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "")
+        self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
+
+        random.seed(self.seed)
+        np.random.seed(self.seed)
+        get_accelerator().manual_seed(self.seed)
+        get_accelerator().manual_seed_all(self.seed)
+
+        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True)
+
+    def __exit__(self, type, value, traceback):
+        random.setstate(self.saved_random_state)
+        np.random.set_state(self.saved_np_random_state)
+        get_accelerator().set_rng_state(self.saved_acc_rng_state)
+        os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config
+        torch.use_deterministic_algorithms(self.saved_deterministic_algorithms)
+
+
+def enable_determinism(seed: int):

+    def decorator(func: Callable) -> Callable:
+
+        def wrapper(*args: Any, **kwargs: Any):
+            with EnableDeterminism(seed):
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@enable_determinism(123)
+def compare_loss(self, config, dtype):
+    iteration = 5
+    hidden_dim = 10
+    RTOL = 5e-1
+    ATOL = 1e-2
+
+    device = torch.device(get_accelerator().current_device_name())
+    model = SimpleModel(hidden_dim)
+
+    i = get_accelerator().current_device()
+    baseline_model = deepcopy(model)
+    baseline_config = deepcopy(config)
+    baseline_config["zero_optimization"]["stage"] = 0
+    baseline_config["zero_optimization"]["offload_optimizer"] = {}
+    baseline_config["compile"]["enabled"] = False
+    baseline_engine, baseline_optimizer, _, _ = deepspeed.initialize(config=baseline_config,
+                                                                     model=baseline_model,
+                                                                     model_parameters=baseline_model.parameters())
+
+    if config["zero_optimization"]["stage"] == 3:
+        with deepspeed.zero.Init(config_dict_or_path=config):
+            target_model = SimpleModel(hidden_dim)
+        with GatheredParameters(target_model.parameters(), modifier_rank=0):
+            for p1, p2 in zip(target_model.parameters(), model.parameters()):
+                p1.data.copy_(p2.data)
+    else:
+        target_model = deepcopy(model)
+
+    target_engine, target_optimizer, _, _ = deepspeed.initialize(config=config,
+                                                                 model=target_model,
+                                                                 model_parameters=target_model.parameters())
+
+    train_batch_size = config["train_micro_batch_size_per_gpu"]
+
+    xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=dtype) for _ in range(iteration)]
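+    # Feed identical inputs/targets to both engines so their losses (and, after the
+    # training loop, their parameters) can be compared directly.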
+    ys = [torch.randn_like(x) for x in xs]
+
+    for x, y in zip(xs, ys):
+        baseline_loss = baseline_engine(x, y)
+        target_loss = target_engine(x, y)
+
+        assert torch.allclose(baseline_loss, target_loss, rtol=RTOL, atol=ATOL)
+
+        baseline_engine.backward(baseline_loss)
+        target_engine.backward(target_loss)
+
+        baseline_optimizer.step()
+        target_optimizer.step()
+
+    with GatheredParameters(target_engine.parameters()):
+        for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
+            assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
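For orientation, here is a minimal sketch (not part of the patch; the module name, class name, and config values are illustrative only) of how a new test would combine the pieces introduced above: the `non_daemonic_procs` switch on `DistributedTest`, the `"compile"` config block, and the `compare_loss` helper from `unit.runtime.compile.util`:

    # hypothetical tests/unit/runtime/compile/test_compile_example.py -- illustration only
    import pytest
    import torch

    import deepspeed
    from unit.common import DistributedTest
    from unit.runtime.compile.util import compare_loss


    class TestCompileExample(DistributedTest):
        world_size = 2
        # Run workers as plain mp.Process objects instead of the daemonic pool,
        # matching the other compile tests in this patch.
        non_daemonic_procs = True

        @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
        def test_fp16_stage2(self):
            config_dict = {
                "train_micro_batch_size_per_gpu": 1,
                "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
                "fp16": {"enabled": True, "initial_scale_power": 8},
                "zero_optimization": {"stage": 2},
                "compile": {"enabled": True, "backend": "inductor"},
            }
            # compare_loss builds an uncompiled ZeRO stage-0 baseline engine and a
            # compiled target engine from this config, then checks that per-step
            # losses and final parameters agree within RTOL/ATOL.
            compare_loss(self, config_dict, torch.float16)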