Fixes for training models with bf16 + freshly initialized optimizer via load_module_only #4141

Merged · 8 commits · Jan 18, 2024
10 changes: 6 additions & 4 deletions deepspeed/runtime/engine.py
```diff
@@ -2745,10 +2745,12 @@ def load_checkpoint(self,
                                              load_module_only=load_module_only,
                                              custom_load_fn=custom_load_fn)

-        load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
-                                                                                    or self.bfloat16_enabled())
+        load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
         if load_zero_checkpoint:
-            success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            if load_optimizer_states and not load_module_only:
+                success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            else:
+                success = False
             if not success:
                 self.optimizer._restore_from_bit16_weights()
```
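The effect of this hunk: a bf16 engine can now load just the module weights and keep a freshly initialized optimizer, because the ZeRO/bf16 branch is still entered, `success` stays `False`, and the fp32 master weights are rebuilt from the loaded 16-bit parameters. A minimal usage sketch of the scenario this fixes, assuming an already-initialized `engine` and an existing checkpoint; `ckpt_dir` and `tag` are illustrative names, not values from this PR:

```python
# Usage sketch (illustrative names): restore module weights only and
# start training from a freshly initialized optimizer under bf16.
load_path, client_state = engine.load_checkpoint(
    ckpt_dir,                     # directory containing the saved checkpoint
    tag=tag,                      # e.g. "global_step1000"
    load_module_only=True,        # read module weights, skip optimizer state
    load_optimizer_states=False,  # do not pull optimizer state from disk
)
# With the fix, success is False on this path, so
# _restore_from_bit16_weights() re-seeds the fp32 master weights from the
# just-loaded bf16 module parameters instead of crashing or silently
# skipping the restore.
```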
```diff
@@ -2830,7 +2832,7 @@ def _load_checkpoint(self,
         optim_checkpoint = None
         if load_module_only:
             deepspeed_states = ['module']
-            if self.optimizer is not None and self.fp16_enabled():
+            if self.optimizer is not None:
                 self.optimizer.refresh_fp32_params()
         else:
             has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
```
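The second hunk drops the `fp16_enabled()` gate so that bf16 engines also refresh their fp32 master copy on a module-only load. Conceptually, `refresh_fp32_params` re-seeds the optimizer's fp32 master weights from the 16-bit module weights; the following is a schematic of that idea, not the DeepSpeed implementation, and the parameter-list names are assumptions:

```python
import torch

def refresh_fp32_params_schematic(bit16_params, fp32_params):
    """Schematic of what refresh_fp32_params accomplishes: copy the
    just-loaded 16-bit weights into the optimizer's fp32 master copy so
    a freshly initialized optimizer steps from checkpointed values."""
    for p16, p32 in zip(bit16_params, fp32_params):
        # Upcast each 16-bit tensor and overwrite its fp32 counterpart.
        p32.data.copy_(p16.data.float())
```

Before this change, that refresh ran only when fp16 was enabled, so a bf16 run using `load_module_only` could leave the optimizer stepping from stale fp32 values rather than the checkpointed weights.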