Improve universal checkpoint (microsoft#5289)
This PR includes the following improvements to universal checkpointing.

- Restoring step: A universal checkpoint saves the training step count taken from the engine. In microsoft#5263, we changed the loader to always set this count both in the optimizer's per-param state (`optimizer_state['state'][param]['step']`) and in each param_group. However, this does not restore the optimizer's state and param groups precisely, because optimizers handle `step` differently: Torch's Adam does not keep `step` in the param_group and only uses `optimizer_state['state'][param]['step']`; Apex's fused Adam only uses `step` in the param_group; DeepSpeed's fused Adam creates `step` in the param_group but never updates it, and only uses `optimizer_state['state'][param]['step']`. This led to discrepancies between the restored and original optimizer states and param groups. This PR modifies the restoration process so that the step number in the optimizer's state and param groups matches the original setup, aligning the restored and original optimizer states and param groups (a sketch of the idea follows below).
- Unit tests of DP size scaling: This PR also adds unit tests to verify universal checkpointing. They run training with DP, save a checkpoint, and convert it into a universal checkpoint. They then load the checkpoint with a different DP size and validate that the parameters and the all-gathered (ZeRO 1/2) optimizer states match (see the outline below).
- Fix bug of loading with `load_optimizer_states=False`: The loader did not load parameters from a universal checkpoint when `load_optimizer_states=False`. microsoft@c8c0498 fixes this issue.
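Below is a minimal sketch, not taken from the PR, of the idea behind the step restoration: write the restored count into whichever location a given optimizer actually reads, so that the restored and original optimizer states and param groups agree. The helper name `restore_step` is hypothetical.

```python
import torch


def restore_step(optimizer: torch.optim.Optimizer, step: int) -> None:
    """Hypothetical helper: push a restored step count to wherever the optimizer keeps it."""
    for group in optimizer.param_groups:
        # Apex's fused Adam reads `step` from the param_group; DeepSpeed's fused Adam
        # creates it there but never updates it. Only overwrite the key if it exists.
        if 'step' in group:
            group['step'] = step
        for p in group['params']:
            state = optimizer.state[p]
            # torch.optim.Adam keeps the count in per-param state (a tensor in recent
            # PyTorch versions), so preserve whichever representation is already there.
            if 'step' in state:
                state['step'] = torch.tensor(float(step)) if torch.is_tensor(state['step']) else step
```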
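Similarly, here is a rough outline of the DP-resize flow the new unit tests exercise. This is illustrative only: the converter script path and its `--input_folder`/`--output_folder` flags come from DeepSpeed's universal-checkpoint tooling and may differ across versions, and `run_dp_resize_flow` is a hypothetical wrapper, not the PR's actual test code.

```python
import subprocess


def run_dp_resize_flow(engine, ckpt_dir: str, tag: str) -> None:
    """Hypothetical wrapper; `engine` is a deepspeed.DeepSpeedEngine from deepspeed.initialize()."""
    # 1. Train with some DP size and save a regular ZeRO checkpoint.
    engine.save_checkpoint(ckpt_dir, tag=tag)

    # 2. Convert the ZeRO checkpoint into a universal checkpoint with the converter
    #    script shipped in the DeepSpeed repository (adjust the path for your checkout).
    subprocess.run([
        "python", "deepspeed/checkpoint/ds_to_universal.py",
        "--input_folder", f"{ckpt_dir}/{tag}",
        "--output_folder", f"{ckpt_dir}/{tag}_universal",
    ], check=True)

    # 3. A second run with a different DP size (and universal-checkpoint loading enabled
    #    in its DeepSpeed config) then loads the converted checkpoint and validates that
    #    parameters and the all-gathered (ZeRO 1/2) optimizer states match the original.
```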
Showing 16 changed files with 359 additions and 106 deletions.
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path)

        self._load_global_state(optim_sd)

        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We have an assumption that all params in the same param_group have the same keys
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
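As a usage sketch, here is how a concrete ZeRO optimizer might drive the shared loader above when resuming from a universal checkpoint. Only `load_hp_checkpoint_state_from_checkpoint_dir` comes from the file above; the class, attribute, and method names in the sketch are illustrative assumptions, not DeepSpeed's actual subclasses.

```python
# Illustrative sketch only; attribute/method names other than the base-class loader are assumptions.
class ExampleZeroOptimizer(ZeROOptimizer):

    def __init__(self, optimizer, bit16_groups, param_names, mpu=None):
        self.optimizer = optimizer        # wrapped torch optimizer holding the fp32 "hp" params
        self.bit16_groups = bit16_groups  # low-precision param groups; each lp exposes `_hp_mapping`
        self.param_names = param_names    # maps each lp param to the folder name it was saved under
        self.mpu = mpu

    def _load_global_state(self, optim_sd):
        # Stage-specific global state (e.g. loss scaler) would be restored here.
        pass

    def load_universal_checkpoint(self, checkpoint_folder: str) -> None:
        # The base-class loader reads <checkpoint_folder>/zero/optimizer_state.pt and the
        # per-parameter folders, restores `step`, and maps per-parameter optimizer states
        # back onto the flat fp32 partitions.
        self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder)
```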