From c3cfe96bb3374ae6d8ff200e7487a1562de43e11 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Mon, 5 Feb 2024 21:11:49 -0800
Subject: [PATCH] Enable torch.compile with ZeRO (Experimental) (#4878)

This PR enables `torch.compile` with ZeRO stages 1/2/3. You need to add a `compile` section to your DeepSpeed config. The fields in this section are passed to `torch.compile`.

```json
"compile": {
    "enabled": true,
    "backend": "inductor"
}
```

To enable a custom backend, you can pass the fully qualified name of the backend function. For example, if you have a backend function `my_backend` in `my_backend.py` in the current directory, you can enable it with `"backend": "my_backend.my_backend"`. You can find an example in [a unit test](https://github.com/microsoft/DeepSpeed/blob/eb9d4e06e9596f391aea305a6a5c6ec70cc28b58/tests/unit/runtime/compile/test_config.py#L116); a usage sketch also follows the patch below.

So far we have validated the results with Megatron-DeepSpeed. See the [example](https://github.com/microsoft/Megatron-DeepSpeed/tree/tohtana/enable_compile/examples_deepspeed/compile) for details.

NOTICE: This PR is a draft. We still need to validate coverage and accuracy with many more examples.

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Michael Wyatt
---
 deepspeed/__init__.py                         |   1 +
 deepspeed/comm/torch.py                       |  24 +++
 deepspeed/runtime/compiler.py                 | 166 ++++++++++++++++++
 deepspeed/runtime/config.py                   |   3 +
 deepspeed/runtime/engine.py                   |  39 ++--
 deepspeed/runtime/zero/mics.py                |   3 +-
 deepspeed/runtime/zero/parameter_offload.py   | 162 ++++++++---------
 .../runtime/zero/partition_parameters.py      |  19 +-
 .../zero/partitioned_param_coordinator.py     |   6 +
 deepspeed/runtime/zero/utils.py               |  23 ++-
 tests/unit/common.py                          |  83 +++++++--
 .../runtime/compile/test_compile_wrapper.py   |  79 +++++++++
 .../unit/runtime/compile/test_compile_zero.py |  63 +++++++
 .../unit/runtime/compile/test_load_config.py  | 114 ++++++++++++
 tests/unit/runtime/compile/util.py            | 124 +++++++++++++
 15 files changed, 784 insertions(+), 125 deletions(-)
 create mode 100644 deepspeed/runtime/compiler.py
 create mode 100644 tests/unit/runtime/compile/test_compile_wrapper.py
 create mode 100644 tests/unit/runtime/compile/test_compile_zero.py
 create mode 100644 tests/unit/runtime/compile/test_load_config.py
 create mode 100644 tests/unit/runtime/compile/util.py

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 61ae016eef4b..a6feae06974b 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -39,6 +39,7 @@
 from .runtime import zero
 from .runtime import DeepSpeedOptimizer, ZeROOptimizer
+from .runtime.compiler import is_compile_supported
 from .pipe import PipelineModule
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 531f54d2628f..d96c63d84e60 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -8,6 +8,7 @@
 from .utils import *
 from .backend import *
 from .comm import *
+from ..runtime import compiler
 import os
 DS_COMM_ALL_GATHER_OFF = False
@@ -120,6 +121,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
         self.init_process_group(backend, timeout, init_method, rank, world_size)
     @classmethod
+    @compiler.disable
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -128,6 +130,7 @@ def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor @@ -150,14 +153,17 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size): world_size=world_size) self.using_mpi = torch.distributed.get_backend() == 'mpi' + @compiler.disable def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) + @compiler.disable def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) + @compiler.disable def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): """ proxy func to torch.distributed.all_reduce_coalesced, which is included in PyTorch 1.13 and above @@ -168,6 +174,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group op = self._reduce_op(op) return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op) + @compiler.disable def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): if DS_COMM_REDUCE_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -175,6 +182,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): return Noop() return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op) + @compiler.disable def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False): if DS_COMM_REDUCE_SCATTER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -187,6 +195,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_ group=group, async_op=async_op) + @compiler.disable def broadcast(self, tensor, src, group=None, async_op=False): if DS_COMM_BROADCAST_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -195,6 +204,7 @@ def broadcast(self, tensor, src, group=None, async_op=False): else: return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) + @compiler.disable def all_gather(self, tensor_list, tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -203,6 +213,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False): else: return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op) + @compiler.disable def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False): if self.has_all_gather_into_tensor(): return self.all_gather_function(output_tensor=output_tensor, @@ -210,6 +221,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_ group=group, async_op=async_op) + @compiler.disable def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -227,6 +239,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals "please consider upgrading your pytorch installation.") pass + @compiler.disable def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False): """""" assert len(output_tensors) == len(input_tensors), "" @@ -250,6 +263,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_ else: reqs[-1].wait() + @compiler.disable def reduce_scatter_tensor(self, 
output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False): if self.has_reduce_scatter_tensor(): return self.reduce_scatter_function(output_tensor, @@ -263,6 +277,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr "please consider upgrading your pytorch installation.") pass + @compiler.disable def all_to_all_single(self, output, input, @@ -277,21 +292,27 @@ def all_to_all_single(self, group=group, async_op=async_op) + @compiler.disable def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False): return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op) + @compiler.disable def send(self, tensor, dst, group=None, tag=0): return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag) + @compiler.disable def recv(self, tensor, src=None, group=None, tag=0): return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag) + @compiler.disable def isend(self, tensor, dst, group=None, tag=0): return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag) + @compiler.disable def irecv(self, tensor, src=None, group=None, tag=0): return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag) + @compiler.disable def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): return torch.distributed.gather(tensor=tensor, gather_list=gather_list, @@ -299,6 +320,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): group=group, async_op=async_op) + @compiler.disable def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): return torch.distributed.scatter(tensor=tensor, scatter_list=scatter_list, @@ -306,11 +328,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): group=group, async_op=async_op) + @compiler.disable def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None): if group is None: group = torch.distributed.GroupMember.WORLD return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids) + @compiler.disable def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False): if group is None: group = torch.distributed.GroupMember.WORLD diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py new file mode 100644 index 000000000000..603f563fca60 --- /dev/null +++ b/deepspeed/runtime/compiler.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Union, Callable, Dict, Any +import importlib +import torch +from ..pydantic_v1 import validator +from .config_utils import DeepSpeedConfigModel + +COMPILE_CONFIG = "compile" + + +def is_compile_supported(): + return hasattr(torch, "compile") + + +def disable(func): + if is_compile_supported(): + return torch.compiler.disable(func) + return func + + +def get_compile_config(param_dict): + if COMPILE_CONFIG in param_dict: + compile_config_dict = param_dict[COMPILE_CONFIG] + else: + compile_config_dict = {} + return CompileConfig(**compile_config_dict) + + +def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]: + if isinstance(backend, Callable): + return backend + + elif isinstance(backend, str): + if backend in torch._dynamo.list_backends(): + return backend + + # Get module name from backend name + module_name = '.'.join(backend.split('.')[:-1]) + fn_name = backend.split('.')[-1] + + try: + module = importlib.import_module(module_name) + backend_fn = getattr(module, fn_name) + except ImportError: + raise ValueError( + f"The backend {backend} is not in the list of available backends and could not be imported.") + return backend_fn + + raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}") + + +class CompileConfig(DeepSpeedConfigModel): + """ + [EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings. + Please be aware that these features and API designs are experimental and subject to change. + """ + + enabled: bool = False + """ + Enable torch.compile when True. + """ + + backend: str = "inductor" + """ + Passed to `backend` argument of torch.compile. + If the given value is not in torch._dynamo.list_backends(), + DeepSpeed attempts to import and instantiate the module with the given name. + """ + + kwargs: Dict[str, Any] = {} + """ + Passed to `kwargs` argument of torch.compile. + """ + + @validator("enabled") + def validate_enabled(cls, field_value, values): + if field_value and not is_compile_supported(): + raise ValueError("torch.compile is not supported on this version of PyTorch.") + return field_value + + +class CompiledModuleWrapper(torch.nn.Module): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + super().__init__() + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + modules = self.__dict__.get('_modules') + modules['wrapped'] = module + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def __getattr__(self, name): + return getattr(self.__dict__['wrapped'], name) + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. 
+ + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) + + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 80754df50c20..20fbf475ca90 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import get_monitor_config from ..inference.config import WeightQuantConfig +from .compiler import get_compile_config from deepspeed import comm as dist from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -899,6 +900,8 @@ def _initialize_params(self, param_dict): self.weight_quantization_config = WeightQuantConfig( **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None + self.compile_config = get_compile_config(param_dict) + def _batch_assertion(self): train_batch = self.train_batch_size diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 58dac4e44b77..f0602813f3ab 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -90,6 +90,7 @@ from .pipe.module import PipelineModule from .utils import get_ma_status +from .compiler import CompiledModuleWrapper from ..ops.adam import FusedAdam from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.layer import MoE @@ -179,21 +180,19 @@ def __init__(self, enable_micro_timers, enable_global_timers): class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" - def __init__( - self, - args, - model, - optimizer=None, - model_parameters=None, - training_data=None, - lr_scheduler=None, - mpu=None, - dist_init_required=None, - collate_fn=None, - config=None, - config_class=None, - dont_change_device=False, - ): + def __init__(self, + args, + model, + optimizer=None, + model_parameters=None, + training_data=None, + lr_scheduler=None, + mpu=None, + dist_init_required=None, + collate_fn=None, + config=None, + config_class=None, + dont_change_device=False): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device self.client_optimizer = optimizer @@ -363,6 +362,9 @@ def __init__( self.flatten = _flatten_dense_tensors self.unflatten = _unflatten_dense_tensors + if 
self._config.compile_config.enabled: + self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config)) + def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() @@ -467,6 +469,13 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) + elif isinstance(_module, CompiledModuleWrapper): + try: + return getattr(_module, name) + except AttributeError: + raise AttributeError( + f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" + ) else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index 05e59eccfdae..1e5c9396be1d 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -12,8 +12,9 @@ import deepspeed import torch from deepspeed import comm as dist +from deepspeed.runtime.zero.utils import is_zero_param from deepspeed.runtime.zero.mics_utils import (MiCS_CommGroups, create_mics_comm_groups, scale_tensors) -from deepspeed.runtime.zero.parameter_offload import (DeepSpeedZeRoOffload, is_zero_param) +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 from deepspeed.utils import instrument_w_nvtx, log_dist diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 56cc4af19840..e9e79c2647fb 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -8,7 +8,7 @@ from collections import OrderedDict from deepspeed.utils import z3_leaf_module from deepspeed.runtime.utils import see_memory_usage -from deepspeed.runtime.zero.utils import apply_to_tensors_only +from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * @@ -21,20 +21,6 @@ warned = False -def _apply_backward_to_tensors_only(module, functional, backward_function, outputs): - - def apply_to_tensor_fn(tensor): - # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - touched_outputs = functional.apply(module, backward_function, tensor) - - # restore zero param attributes if those get stripped by `backward_function` - if not is_zero_param(touched_outputs) and is_zero_param(tensor): - touched_outputs.ds_param_alias = tensor - return touched_outputs - - return apply_to_tensors_only(apply_to_tensor_fn, outputs) - - #for each tensor in outputs run the forward_function and register backward_function as hook def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): if type(outputs) is tuple: @@ -94,54 +80,6 @@ def _inject_parameters(module, cls): module._parameters = new_param -class PreBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After 
Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, post_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.post_backward_function = post_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.post_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - class DeepSpeedZeRoOffload(object): def __init__( @@ -341,6 +279,7 @@ def _pre_forward_module_hook(module, *args): @instrument_w_nvtx def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK FWD_MODULE_STACK.pop() if output is None: @@ -381,20 +320,49 @@ def _post_forward_module_hook(module, input, output): self.post_sub_module_forward_function(module) - def _pre_backward_module_hook(module, inputs, output): + def _bwd_hook_unexpected_inputs_msg(value): + return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \ + "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " \ + "output tensors and therefore may not get triggered properly." - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + def _pre_backward_module_hook(module, inputs, output): - return _apply_backward_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output) + if not hasattr(module, "pre_bwd_fn"): + + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. 
Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, outputs): + # Capture `module` and _run_before_backward_function + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args + + module.pre_bwd_fn = PreBackwardFunctionForModule + + return apply_to_tensors_only(module.pre_bwd_fn.apply, + output, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) #This is an alternate to doing _post_backward_module_hook #it uses tensor.register_hook instead of using torch.autograd.Function @@ -419,12 +387,44 @@ def _run_before_forward_function(input): def _post_backward_module_hook(module, inputs): module.ds_grads_remaining = 0 - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_backward_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs) + if not hasattr(module, "post_bwd_fn"): + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." 
+ #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args + + module.post_bwd_fn = PostBackwardFunctionModule + + return apply_to_tensors_only(module.post_bwd_fn.apply, + inputs, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) # Pre forward hook self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index c55cf5bdfa17..99a9d100082b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -25,7 +25,7 @@ import deepspeed from ..utils import see_memory_usage from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.config_utils import get_config_default from deepspeed.utils import instrument_w_nvtx, logger @@ -109,12 +109,6 @@ def debug_rank0(msg: str) -> None: logger.debug(msg) -def is_zero_param(parameter): - if not torch.is_tensor(parameter): - return False - return hasattr(parameter, 'ds_id') - - def _init_external_params(module): if not hasattr(module, '_external_params'): module._external_params = {} @@ -911,7 +905,16 @@ def __init__( _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) if config_dict_or_path is not None else None if _ds_config is not None: - mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled: + # memory_efficient_linear displays numerous errors when torch.compile is enabled. + # Refer to https://github.com/pytorch/pytorch/issues/119059 for details. + # Further investigation into performance is necessary, even after resolving this issue because + # the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation. 
+ logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.') + mem_efficient_linear = False + else: + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) if not dist.is_initialized(): init_distributed() diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index a3f2319b6a2c..cfeae9e7839a 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -17,6 +17,8 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator +import deepspeed.runtime.compiler as compiler + import logging ENABLE_PROFILER = False @@ -175,6 +177,7 @@ def trace_prologue(self, sub_module: Module) -> None: force=True) self._invalidate_trace() + @compiler.disable def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" if not self.is_record_trace(): @@ -252,6 +255,7 @@ def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None): Fetching, prefetching, and releasing parameters """ + @compiler.disable @instrument_w_nvtx @torch.no_grad() def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: @@ -272,6 +276,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))) fetch_numel = sum( [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + if fetch_numel > 0: event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT self._dump_param_ids(event_name, current_submodule.id, @@ -468,6 +473,7 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: if swap_persisted_params: swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params) + @compiler.disable @instrument_w_nvtx def __release_param(self, param: Parameter) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 78eaaba59ebb..f61715bd4387 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -108,19 +108,22 @@ def isinstance_namedtuple(obj: object) -> bool: return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') +def is_zero_param(parameter): + if not torch.is_tensor(parameter): + return False + return hasattr(parameter, 'ds_id') + + def apply_to_tensors_only(function, value, warning_msg_fn=None): """ Apply `function` to every Tensor in `value`. Args: - module (torch.nn.Module): A torch module - functional (Type[torch.autograd.Function]): The function class to apply. - backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to - `functional.apply`. - outputs (Any): The output of `module`. + functional: The function class to apply. + value (Any): Target object to apply `function` to. Returns: - Any: The output of `module`. + Any: Output of `function`. 
""" if isinstance(value, (tuple, list)): touched_outputs = [] @@ -141,7 +144,13 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None): elif isinstance(value, torch.Tensor): # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - return function(value) + touched_output = function(value) + + # restore zero param attributes if those get stripped by `backward_function` + if not is_zero_param(touched_output) and is_zero_param(value): + touched_output.ds_param_alias = value + + return touched_output else: if not is_builtin_type(value): global warned diff --git a/tests/unit/common.py b/tests/unit/common.py index cdeca54b01ee..420db577cf09 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -113,6 +113,7 @@ class DistributedExec(ABC): set_dist_env = True requires_cuda_env = True reuse_dist_env = False + non_daemonic_procs = False _pool_cache = {} exec_timeout = DEEPSPEED_TEST_TIMEOUT @@ -145,16 +146,7 @@ def _get_fixture_kwargs(self, request, func): pass # test methods can have kwargs that are not fixtures return fixture_kwargs - def _launch_procs(self, num_procs): - # Verify we have enough accelerator devices to run this test - if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: - pytest.skip( - f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" - ) - - # Set start method to `forkserver` (or `fork`) - mp.set_start_method('forkserver', force=True) - + def _launch_daemonic_procs(self, num_procs): # Create process pool or use cached one master_port = None if self.reuse_dist_env: @@ -186,8 +178,70 @@ def _launch_procs(self, num_procs): assert len(set(skip_msgs)) == 1, "Multiple different skip messages received" pytest.skip(skip_msgs[0]) - def _dist_run(self, local_rank, num_procs, master_port): - skip_msg = '' + def _launch_non_daemonic_procs(self, num_procs): + assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes" + + master_port = get_master_port() + skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason + processes = [] + for local_rank in range(num_procs): + p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg)) + p.start() + processes.append(p) + + # Now loop and wait for a test to complete. The spin-wait here isn't a big + # deal because the number of processes will be O(#GPUs) << O(#CPUs). + any_done = False + start = time.time() + while (not any_done) and ((time.time() - start) < self.exec_timeout): + for p in processes: + if not p.is_alive(): + any_done = True + break + time.sleep(.1) # So we don't hog CPU + + # If we hit the timeout, then presume a test is hanged + if not any_done: + for p in processes: + p.terminate() + pytest.exit("Test hanged, exiting", returncode=0) + + # Wait for all other processes to complete + for p in processes: + p.join(self.exec_timeout) + + failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0] + for rank, p in failed: + # If it still hasn't terminated, kill it because it hung. 
+ if p.exitcode is None: + p.terminate() + pytest.fail(f'Worker {rank} hung.', pytrace=False) + if p.exitcode < 0: + pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False) + if p.exitcode > 0: + pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False) + + if not skip_msg.empty(): + # This assumed all skip messages are the same, it may be useful to + # add a check here to assert all exit messages are equal + pytest.skip(skip_msg.get()) + + def _launch_procs(self, num_procs): + # Verify we have enough accelerator devices to run this test + if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: + pytest.skip( + f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" + ) + + # Set start method to `forkserver` (or `fork`) + mp.set_start_method('forkserver', force=True) + + if self.non_daemonic_procs: + self._launch_non_daemonic_procs(num_procs) + else: + self._launch_daemonic_procs(num_procs) + + def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""): if not dist.is_initialized(): """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: @@ -218,7 +272,10 @@ def _dist_run(self, local_rank, num_procs, master_port): self.run(**self._fixture_kwargs) except BaseException as e: if isinstance(e, Skipped): - skip_msg = e.msg + if self.non_daemonic_procs: + skip_msg.put(e.msg) + else: + skip_msg = e.msg else: raise e diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py new file mode 100644 index 000000000000..fbf235fb7d62 --- /dev/null +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +import deepspeed +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + + +@pytest.fixture +def base_config(): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + }, + "compile": { + "enabled": True, + "backend": "inductor" + } + } + return config_dict + + +class SmallModelWithCustomMethod(torch.nn.Module): + + def __init__(self, hidden_dim, test_value): + super(SmallModelWithCustomMethod, self).__init__() + self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v = test_value + + def forward(self, x): + return self.fc(x) + + # Custom function that is not part of DeepSpeed engine. 
+ def get_v(self): + return self.v + + +class TestCustomMethod(DistributedTest): + world_size = 1 + non_daemonic_procs = True + + def _init_engine(self, config, test_value): + hidden_dim = 10 + model = SmallModelWithCustomMethod(hidden_dim, test_value) + engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + return engine + + def _run_model(self, engine): + train_batch_size = 1 + device = torch.device(get_accelerator().current_device_name()) + dtype = engine.module.fc.weight.dtype + hidden_dim = engine.module.fc.weight.shape[1] + x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype) + engine(x) + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_custom_function(self, base_config): + test_value = 10 + + engine = self._init_engine(base_config, test_value) + assert engine.module.get_v() == test_value + self._run_model(engine) + + # The model is compiled after the first run. + # Thus we make sure the custom method is still available after compilation. + assert engine.module.get_v() == test_value diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py new file mode 100644 index 000000000000..87e3c52b9e3c --- /dev/null +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum + +from unit.runtime.compile.util import compare_loss +from unit.common import DistributedTest +from unit.util import bf16_required_version_check + + +class TestZeRO(DistributedTest): + world_size = 2 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) + def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if offload_device == OffloadDeviceEnum.nvme: + if zero_stage != 3: + pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "enabled": True, + "backend": "inductor" + } + } + + if offload_device == OffloadDeviceEnum.cpu: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} + elif offload_device == OffloadDeviceEnum.nvme: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": offload_device, + "nvme_path": str(tmpdir) + } + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + compare_loss(self, config_dict, dtype) diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py new file mode 100644 index 000000000000..351e91d2f69b --- /dev/null +++ b/tests/unit/runtime/compile/test_load_config.py @@ -0,0 +1,114 @@ +# 
Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from unit.simple_model import SimpleModel +import deepspeed +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +custom_backend_called = False +custom_compler_fn_called = False + +if deepspeed.is_compile_supported(): + # PyTorch v1 does not have torch.fx + def custom_backend(gm: torch.fx.GraphModule, example_inputs): + global custom_backend_called + custom_backend_called = True + return gm.forward + + def custom_compiler_fn(module: torch.nn.Module): + global custom_compler_fn_called + custom_compler_fn_called = True + return torch.compile(module) + + +@pytest.fixture +def base_config(): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + }, + "compile": { + "enabled": True, + "backend": "inductor" + } + } + return config_dict + + +class TestConfigLoad(DistributedTest): + world_size = 1 + non_daemonic_procs = True + + def _init_engine(self, config): + hidden_dim = 10 + model = SimpleModel(hidden_dim) + engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + return engine + + def _run_model(self, engine): + train_batch_size = 1 + device = torch.device(get_accelerator().current_device_name()) + dtype = engine.module.linears[0].weight.dtype + hidden_dim = engine.module.linears[0].weight.shape[1] + x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype) + y = torch.randn_like(x) + engine(x, y) + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_compile(self, base_config): + engine = self._init_engine(base_config) + self._run_model(engine) + assert engine.is_compiled + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_custom_backend(self, base_config): + global custom_backend_called + custom_backend_called = False + + engine = self._init_engine(base_config) + engine.set_backend(f"{__name__}.custom_backend") + self._run_model(engine) + assert custom_backend_called + + def test_compile_disabled(self, base_config): + base_config["compile"]["enabled"] = False + engine = self._init_engine(base_config) + self._run_model(engine) + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_compile_kwargs(self, base_config): + base_config["compile"]["kwargs"] = {"mode": "default"} + engine = self._init_engine(base_config) + self._run_model(engine) + assert "mode" in engine.torch_compile_kwargs + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_set_compile_kwargs(self, base_config): + engine = self._init_engine(base_config) + engine.set_torch_compile_kwargs({"mode": "default"}) + self._run_model(engine) + assert "mode" in engine.torch_compile_kwargs + + @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") + def test_set_compiler_fn(self, base_config): + global custom_compler_fn_called + custom_compler_fn_called = False + + engine = self._init_engine(base_config) + engine.set_compiler_fn(custom_compiler_fn) + self._run_model(engine) + assert custom_compler_fn_called diff --git a/tests/unit/runtime/compile/util.py b/tests/unit/runtime/compile/util.py new file mode 100644 index 
000000000000..86eadf3f6976 --- /dev/null +++ b/tests/unit/runtime/compile/util.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +import os +import numpy as np +from copy import deepcopy + +import torch + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero import GatheredParameters + +from unit.simple_model import SimpleModel +from typing import Callable, Any + + +class EnableDeterminism: + + def __init__(self, seed: int): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + self.seed = seed + local_rank + self.saved_random_state = None + self.saved_np_random_state = None + self.saved_cuda_launch_blocking = None + self.saved_cublas_workspace_config = None + self.saved_deterministic_algorithms = None + + def __enter__(self): + self.saved_random_state = random.getstate() + self.saved_np_random_state = np.random.get_state() + self.saved_acc_rng_state = get_accelerator().get_rng_state() + self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "") + self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") + self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled() + + random.seed(self.seed) + np.random.seed(self.seed) + get_accelerator().manual_seed(self.seed) + get_accelerator().manual_seed_all(self.seed) + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + def __exit__(self, type, value, traceback): + random.setstate(self.saved_random_state) + np.random.set_state(self.saved_np_random_state) + get_accelerator().set_rng_state(self.saved_acc_rng_state) + os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking + os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config + torch.use_deterministic_algorithms(self.saved_deterministic_algorithms) + + +def enable_determinism(seed: int): + + def decorator(func: Callable) -> Callable: + + def wrapper(*args: Any, **kwargs: Any): + with EnableDeterminism(seed): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +@enable_determinism(123) +def compare_loss(self, config, dtype): + iteration = 5 + hidden_dim = 10 + RTOL = 5e-1 + ATOL = 1e-2 + + device = torch.device(get_accelerator().current_device_name()) + model = SimpleModel(hidden_dim) + + i = get_accelerator().current_device() + baseline_model = deepcopy(model) + baseline_config = deepcopy(config) + baseline_config["zero_optimization"]["stage"] = 0 + baseline_config["zero_optimization"]["offload_optimizer"] = {} + baseline_config["compile"]["enabled"] = False + baseline_engine, baseline_optimizer, _, _ = deepspeed.initialize(config=baseline_config, + model=baseline_model, + model_parameters=baseline_model.parameters()) + + if config["zero_optimization"]["stage"] == 3: + with deepspeed.zero.Init(config_dict_or_path=config): + target_model = SimpleModel(hidden_dim) + with GatheredParameters(target_model.parameters(), modifier_rank=0): + for p1, p2 in zip(target_model.parameters(), model.parameters()): + p1.data.copy_(p2.data) + else: + target_model = deepcopy(model) + + target_engine, target_optimizer, _, _ = deepspeed.initialize(config=config, + model=target_model, + model_parameters=target_model.parameters()) + + train_batch_size = config["train_micro_batch_size_per_gpu"] + + xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=dtype) for _ in 
range(iteration)] + ys = [torch.randn_like(x) for x in xs] + + for x, y in zip(xs, ys): + baseline_loss = baseline_engine(x, y) + target_loss = target_engine(x, y) + + assert torch.allclose(baseline_loss, target_loss, rtol=RTOL, atol=ATOL) + + baseline_engine.backward(baseline_loss) + target_engine.backward(target_loss) + + baseline_optimizer.step() + target_optimizer.step() + + with GatheredParameters(target_engine.parameters()): + for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()): + assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
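
A minimal usage sketch of the feature added by this patch. The model, hidden size, optimizer settings, and the `my_backend` function are illustrative placeholders rather than code from the patch; only the `compile` config section and the engine-side `set_backend` / `set_torch_compile_kwargs` / `set_compiler_fn` helpers come from the changes above. The script is assumed to run under the `deepspeed` launcher on an accelerator supported by `torch.compile`.

```python
# Minimal sketch: enable torch.compile through the DeepSpeed config.
# SimpleNet, hidden_dim, my_backend, and the optimizer settings are placeholders;
# only the "compile" section and the set_* engine helpers come from this patch.
import torch
import deepspeed


class SimpleNet(torch.nn.Module):

    def __init__(self, hidden_dim=10):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.fc(x)


def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Trivial custom torch.compile backend: run the captured graph eagerly.
    # If this lives in my_backend.py, the config can reference it as
    # "backend": "my_backend.my_backend"; it can also be passed to set_backend().
    return gm.forward


ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "zero_optimization": {"stage": 3},
    "compile": {
        "enabled": True,        # wrap the module in CompiledModuleWrapper
        "backend": "inductor",  # a torch._dynamo backend name or an import path
    },
}

model = SimpleNet()
engine, _, _, _ = deepspeed.initialize(config=ds_config,
                                       model=model,
                                       model_parameters=model.parameters())

# Optional overrides; these must happen before the first forward pass because
# CompiledModuleWrapper compiles lazily on the first call to engine(...).
engine.set_torch_compile_kwargs({"mode": "default"})
# engine.set_backend(my_backend)  # or engine.set_backend("my_backend.my_backend")

x = torch.rand(1, 10, device=engine.device)
loss = engine(x).mean()
engine.backward(loss)
engine.step()
assert engine.is_compiled  # compilation happened during the first forward
```

Because `CompiledModuleWrapper.forward` compiles the wrapped module on first use, `set_backend`, `set_torch_compile_kwargs`, and `set_compiler_fn` only take effect if they are called before the first forward pass.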