Add and Remove ZeRO 3 Hooks (#5658)
Gives the ability to add and remove the ZeRO-3 forward hooks by using a
context manager. These code changes were taken from a Hugging Face TRL
[PR](huggingface/trl#1617) and integrated for direct support in DeepSpeed.

This is useful for inference, and the resulting speedup can be observed
[here](huggingface/trl#1483).
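
For reference, here is a minimal usage sketch of the new context manager during generation. The `model`, `ds_config`, `input_ids`, and the Hugging Face-style `generate()` call are illustrative assumptions, not part of this commit:

```python
import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation

# Assumed: `model` is a torch.nn.Module (e.g. a Hugging Face causal LM) and
# `ds_config` is a ZeRO stage-3 DeepSpeed config dict like the one in the
# unit test added below.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# Inside the context the ZeRO-3 parameters are gathered once and the
# forward/backward hooks are removed, so repeated forward passes during
# generation avoid the per-call gather/partition overhead.
with unwrap_model_for_generation(engine) as unwrapped:
    output_ids = unwrapped.module.generate(input_ids, max_new_tokens=64)

# On exit the hooks are re-registered and the parameters are re-partitioned.
```

This mirrors the usage exercised in the unit test added in this commit.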

---------

Co-authored-by: root <root@deepspeed-c000004.2d1icxc5dsxehnpuwt3ifc34ph.gvxx.internal.cloudapp.net>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
5 people authored Aug 16, 2024
1 parent 1ab1928 commit c2e3a70
Showing 3 changed files with 103 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/zero/__init__.py
@@ -13,3 +13,5 @@
from .tiling import TiledLinearReturnBias

from .mics import MiCS_Init

from .stage3 import unwrap_model_for_generation
34 changes: 34 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -7,6 +7,7 @@
import gc
import collections
from typing import Deque, Dict, Tuple
from contextlib import contextmanager
from deepspeed import comm as dist
from deepspeed.utils import groups

@@ -69,6 +70,39 @@ def move_to_cpu(tensor_list):
        tensor.data = tensor.data.cpu()


@contextmanager
def unwrap_model_for_generation(model):
    """
    For ZeRO-3 models, we gather the weights once to speed up generation.
    """
    with GatheredParameters(model.parameters()):
        # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.

        # Remove hooks
        if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
            optimizer_offload = model.optimizer.parameter_offload
        elif model.optimizer is not None:
            optimizer_offload = model.optimizer

        for hook in optimizer_offload.forward_hooks:
            hook.remove()
        for hook in optimizer_offload.backward_hooks:
            hook.remove()

        optimizer_offload.forward_hooks = []
        optimizer_offload.backward_hooks = []

        yield model

        # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.
        if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
            optimizer_offload = model.optimizer.parameter_offload
        elif model.optimizer is not None:
            optimizer_offload = model.optimizer
        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
    return


INITIAL_MICRO_STEP_ID = -1


67 changes: 67 additions & 0 deletions tests/unit/runtime/zero/test_unwrap_model.py
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest
from unit.simple_model import SimpleModel

config = {
    "train_batch_size": 2,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        }
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}


class TestUnwrapModel(DistributedTest):
    # gather across more than 1 gpu
    world_size = 2

    def test(self):

        def hooks_exist(engine):
            if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"):
                optimizer_offload = engine.optimizer.parameter_offload
            elif engine.optimizer is not None:
                optimizer_offload = engine.optimizer

            hooks = 0
            for hook in optimizer_offload.forward_hooks:
                hooks += 1
            if hooks > 0:
                return True
            return False

        model = SimpleModel(hidden_dim=100)
        engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)

        with unwrap_model_for_generation(engine):
            # assert no hooks
            assert not hooks_exist(engine)
            # assert parameters gathered
            assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

        # assert hooks
        assert hooks_exist(engine)
