From c2e3a706b5a39517e7179baa7c25c60d17a2e3c6 Mon Sep 17 00:00:00 2001
From: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Date: Fri, 16 Aug 2024 11:40:25 -0700
Subject: [PATCH] Add and Remove ZeRO 3 Hooks (#5658)

Adds the ability to remove and re-add the ZeRO-3 forward and backward
hooks via a context manager. These changes were adapted from a Hugging Face
[PR](https://github.com/huggingface/trl/pull/1617) and integrated for
direct support in DeepSpeed.

This is useful for inference; the resulting speedup can be observed
[here](https://github.com/huggingface/trl/pull/1483).

---------

Co-authored-by: root
Co-authored-by: Olatunji Ruwase
Co-authored-by: Heyang Qin
Co-authored-by: Logan Adams
---
 deepspeed/runtime/zero/__init__.py           |  2 +
 deepspeed/runtime/zero/stage3.py             | 34 ++++++++++
 tests/unit/runtime/zero/test_unwrap_model.py | 67 ++++++++++++++++++++
 3 files changed, 103 insertions(+)
 create mode 100644 tests/unit/runtime/zero/test_unwrap_model.py

diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py
index 1ccca09a9e69..23fcf9ec13fb 100644
--- a/deepspeed/runtime/zero/__init__.py
+++ b/deepspeed/runtime/zero/__init__.py
@@ -13,3 +13,5 @@
 from .tiling import TiledLinearReturnBias

 from .mics import MiCS_Init
+
+from .stage3 import unwrap_model_for_generation
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 65d1c5ace08f..796957a4c6e5 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -7,6 +7,7 @@
 import gc
 import collections
 from typing import Deque, Dict, Tuple
+from contextlib import contextmanager

 from deepspeed import comm as dist
 from deepspeed.utils import groups
@@ -69,6 +70,39 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()


+@contextmanager
+def unwrap_model_for_generation(model):
+    """
+    For ZeRO-3 models, we gather the weights once to speed up generation.
+    """
+    with GatheredParameters(model.parameters()):
+        # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.
+
+        # Remove hooks
+        if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+            optimizer_offload = model.optimizer.parameter_offload
+        elif model.optimizer is not None:
+            optimizer_offload = model.optimizer
+
+        for hook in optimizer_offload.forward_hooks:
+            hook.remove()
+        for hook in optimizer_offload.backward_hooks:
+            hook.remove()
+
+        optimizer_offload.forward_hooks = []
+        optimizer_offload.backward_hooks = []
+
+        yield model
+
+        # Adds the optimizer hooks back to a DeepSpeed ZeRO-3 model.
+        if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+            optimizer_offload = model.optimizer.parameter_offload
+        elif model.optimizer is not None:
+            optimizer_offload = model.optimizer
+        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+    return
+
+
 INITIAL_MICRO_STEP_ID = -1
diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py
new file mode 100644
index 000000000000..d75519b67f68
--- /dev/null
+++ b/tests/unit/runtime/zero/test_unwrap_model.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed
+from deepspeed.runtime.zero import unwrap_model_for_generation
+from deepspeed.accelerator import get_accelerator
+
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel
+
+config = {
+    "train_batch_size": 2,
+    "steps_per_print": 1,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 0.00015
+        }
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "stage3_param_persistence_threshold": 1,
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": True
+        }
+    }
+}
+
+if get_accelerator().is_fp16_supported():
+    config["fp16"] = {"enabled": True, "loss_scale": 138.}
+elif get_accelerator().is_bf16_supported():
+    config["bf16"] = {"enabled": True}
+
+
+class TestUnwrapModel(DistributedTest):
+    # gather across more than 1 gpu
+    world_size = 2
+
+    def test(self):
+
+        def hooks_exist(engine):
+            if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"):
+                optimizer_offload = engine.optimizer.parameter_offload
+            elif engine.optimizer is not None:
+                optimizer_offload = engine.optimizer
+
+            hooks = 0
+            for hook in optimizer_offload.forward_hooks:
+                hooks += 1
+            if hooks > 0:
+                return True
+            return False
+
+        model = SimpleModel(hidden_dim=100)
+        engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)
+
+        with unwrap_model_for_generation(engine):
+            # assert no hooks
+            assert not hooks_exist(engine)
+            # assert parameters gathered
+            assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"
+
+        # assert hooks
+        assert hooks_exist(engine)
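
A minimal usage sketch of the new context manager, assuming a ZeRO-3 engine returned by `deepspeed.initialize` that wraps a Hugging Face-style causal LM; the model name, tokenizer, config values, and `generate` call below are illustrative assumptions, not part of the patch:

```python
import torch
import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative ZeRO-3 config; any stage-3 config with an optimizer works.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3, "stage3_param_persistence_threshold": 1},
    "bf16": {"enabled": True},
}

# Illustrative model choice; any causal LM with a `generate` method would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(engine.device)

# Inside the context the ZeRO-3 forward/backward hooks are removed and the
# full parameters are gathered, so the repeated forward passes of generation
# avoid per-module all-gathers; the hooks are re-registered on exit.
with unwrap_model_for_generation(engine) as unwrapped_engine:
    with torch.no_grad():
        output_ids = unwrapped_engine.module.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```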