MOE: Fix save checkpoint when TP > 1 (#5157)
Currently, when using MoE, only mp_rank_00_model_states.pt is saved.
This fails when TP > 1.
Fix it by saving all required mp_rank_xx_model_states.pt files.

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
mosheisland authored Feb 21, 2024
1 parent 7f0950f commit a84d07c
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions deepspeed/runtime/engine.py
@@ -3224,22 +3224,21 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         # In the case of E + D parallelism, only the
         # first expert parallel group should save the expert weights
         # since each expert parallel group is a copy of the model's experts
-        if exp_dp_rank != 0:
-            return
-
-        # Save optimizer states. They are different across each exp parallel rank.
-        optimizer_state = {
-            'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
-        }
-        # TODO: why use BufferedWriter not the path
-        file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-        self.checkpoint_engine.save(optimizer_state, file_path)
-
-        # get non-moe parameters
-        model_state_dict = self._get_non_moe_state_dict(
-            self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
-
-        if expp_rank == 0:
+        if exp_dp_rank == 0:
+            # Save optimizer states. They are different across each exp parallel rank.
+            optimizer_state = {
+                'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
+            }
+            # TODO: why use BufferedWriter not the path
+            file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
+            self.checkpoint_engine.save(optimizer_state, file_path)
+
+        # Load flow uses below saved file for model parameters, RNG and more
+        if groups._get_data_parallel_rank() == 0:
+            # get non-moe parameters
+            model_state_dict = self._get_non_moe_state_dict(
+                self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
+
             # TODO: update num experts info,.. in checkpoint
             state = {
                 'module':
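
For context, the change splits the save into two gates: optimizer shards are still written once per expert-parallel rank (guarded by exp_dp_rank == 0), while the mp_rank_xx_model_states.pt files are now written by every rank whose data-parallel rank is 0, so with TP > 1 each tensor-parallel shard gets its own file instead of only mp_rank_00_model_states.pt. Below is a minimal sketch, not DeepSpeed code, of which model-state files end up on disk under this gating; tp_rank and dp_rank are simplified, hypothetical stand-ins for the real parallel-group ranks.

```python
# Minimal sketch (illustrative only): which model-state checkpoint files a rank
# writes under the new gating, assuming tp_rank/dp_rank describe its position
# in the tensor-parallel and data-parallel groups.

def model_state_files_written(tp_rank: int, dp_rank: int) -> list[str]:
    # Model-state saving is gated on data-parallel rank 0, so with TP > 1 every
    # tensor-parallel rank contributes its own mp_rank_xx_model_states.pt shard
    # instead of only mp_rank_00_model_states.pt being written.
    if dp_rank != 0:
        return []
    return [f"mp_rank_{tp_rank:02d}_model_states.pt"]

if __name__ == "__main__":
    # With TP=2 and data-parallel rank 0 on both TP ranks:
    for tp in range(2):
        print(tp, model_state_files_written(tp_rank=tp, dp_rank=0))
    # 0 ['mp_rank_00_model_states.pt']
    # 1 ['mp_rank_01_model_states.pt']
```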
