Commit
Adds the ability to add and remove the ZeRO-3 forward hooks via a context manager. These changes were adapted from a Hugging Face TRL [PR](huggingface/trl#1617) and integrated for direct support in DeepSpeed. This is useful for inference, and the resulting speedup can be observed [here](huggingface/trl#1483).

Co-authored-by: root <root@deepspeed-c000004.2d1icxc5dsxehnpuwt3ifc34ph.gvxx.internal.cloudapp.net>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
1 parent 1ab1928 · commit c2e3a70
Showing 3 changed files with 103 additions and 0 deletions.
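For orientation before the diff, here is a minimal usage sketch of the new context manager. It is not taken from this commit: it assumes a ZeRO stage-3 engine returned by deepspeed.initialize() whose wrapped module has a Hugging Face-style generate() method, and `model`, `config`, and `prompt_ids` are hypothetical placeholders.

import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation

# `model`, `config`, and `prompt_ids` are placeholders, not part of this commit.
engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)

with unwrap_model_for_generation(engine):
    # Inside the context, the ZeRO-3 forward hooks are removed and the
    # partitioned parameters are gathered, so repeated generation forward
    # passes avoid per-call gather/partition overhead.
    outputs = engine.module.generate(prompt_ids, max_new_tokens=32)

# On exit, the hooks are re-registered and the parameters are re-partitioned,
# so training can resume as usual.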
@@ -0,0 +1,67 @@

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest
from unit.simple_model import SimpleModel

config = {
    "train_batch_size": 2,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        }
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}


class TestUnwrapModel(DistributedTest):
    # gather across more than 1 gpu
    world_size = 2

    def test(self):

        def hooks_exist(engine):
            if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"):
                optimizer_offload = engine.optimizer.parameter_offload
            elif engine.optimizer is not None:
                optimizer_offload = engine.optimizer

            hooks = 0
            for hook in optimizer_offload.forward_hooks:
                hooks += 1
            if hooks > 0:
                return True
            return False

        model = SimpleModel(hidden_dim=100)
        engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)

        with unwrap_model_for_generation(engine):
            # assert no hooks
            assert not hooks_exist(engine)
            # assert parameters gathered
            assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

        # assert hooks
        assert hooks_exist(engine)