Fixes for training models with bf16 + freshly initialized optimizer via `load_module_only` (#4141)

This PR fixes the case where we want to resume training from a
DeepSpeed ZeRO checkpoint with a freshly initialized optimizer, without
using the optimizer state stored in the checkpoint or relying on its
existence at all.

In this situation, despite passing `load_module_only=True` and
`load_optimizer_states=False` to `load_checkpoint()`, the previous
behavior was that:
- `self._load_zero_checkpoint` would still be called, which attempts to
load from the (in this case, nonexistent) checkpoint files. This PR
stops this function from being called when `load_module_only=True` and
`load_optimizer_states=False` are used. Alternatively, calling this
function may be acceptable if `"load_from_fp32_weights": true` is set in
the DeepSpeed ZeRO config (reference:
https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733),
but this parameter does not seem to be documented in the docs for ZeRO
config dicts.
- In `_load_checkpoint`, the following code block:
```
if self.optimizer is not None and self.fp16_enabled():
    self.optimizer.refresh_fp32_params()
```
results in `self.optimizer.refresh_fp32_params()` being called only if
using FP16. As a result, the FP32 optimizer state is never initialized
from the 16-bit model weights. This PR removes the fp16-specific
condition.
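
For illustration only, the gating change in `load_checkpoint` can be
modeled with two pure functions that record which engine methods would
be called. This is a minimal sketch based on the lines shown in this
commit's diff, not DeepSpeed code: the names `old_actions`,
`new_actions`, `has_load_path`, `zero_or_bf16` and `zero_load_ok` are
hypothetical stand-ins for the engine state.

```python
def old_actions(load_optimizer_states, load_module_only, has_load_path,
                zero_or_bf16, zero_load_ok=True):
    """Model of the removed lines: which engine methods would be called."""
    actions = []
    # Old gate: load_optimizer_states is part of the condition itself.
    load_zero_checkpoint = load_optimizer_states and has_load_path and zero_or_bf16
    if load_zero_checkpoint:
        actions.append("_load_zero_checkpoint")
        success = zero_load_ok  # hypothetical stand-in for the real return value
        if not success:
            actions.append("_restore_from_bit16_weights")
    return actions


def new_actions(load_optimizer_states, load_module_only, has_load_path,
                zero_or_bf16, zero_load_ok=True):
    """Model of the added lines: which engine methods would be called."""
    actions = []
    # New gate: no longer depends on load_optimizer_states.
    load_zero_checkpoint = has_load_path and zero_or_bf16
    if load_zero_checkpoint:
        if load_optimizer_states and not load_module_only:
            actions.append("_load_zero_checkpoint")
            success = zero_load_ok  # hypothetical stand-in for the real return value
        else:
            # Skip the ZeRO checkpoint files entirely.
            success = False
        if not success:
            actions.append("_restore_from_bit16_weights")
    return actions


# Module-only load with a fresh optimizer (ZeRO or bf16 enabled).
print(old_actions(False, True, True, True))
print(new_actions(False, True, True, True))
```

In this model, a module-only load with `load_optimizer_states=False`
never touches the ZeRO checkpoint files under the new logic and goes
straight to `_restore_from_bit16_weights()`, while a full resume still
calls `_load_zero_checkpoint` as before.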


Previously reported in:
EleutherAI/gpt-neox#947
EleutherAI/gpt-neox#843

Should also close:
#4017

Fixes: #4944 and #4017

This caused problems for a freshly converted Llama checkpoint, which did
not contain optimizer states, when trying to train with this model as
initialization. I have confirmed the following fixes prevent this
behavior.

cc @Quentin-Anthony @zhangir-azerbayev

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams
3 people authored Jan 18, 2024
1 parent 740080c commit 870ae04
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions deepspeed/runtime/engine.py
@@ -2745,10 +2745,12 @@ def load_checkpoint(self,
                                                          load_module_only=load_module_only,
                                                          custom_load_fn=custom_load_fn)
 
-        load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
-                                                                                    or self.bfloat16_enabled())
+        load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
         if load_zero_checkpoint:
-            success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            if load_optimizer_states and not load_module_only:
+                success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            else:
+                success = False
             if not success:
                 self.optimizer._restore_from_bit16_weights()
 
@@ -2830,7 +2832,7 @@ def _load_checkpoint(self,
         optim_checkpoint = None
         if load_module_only:
             deepspeed_states = ['module']
-            if self.optimizer is not None and self.fp16_enabled():
+            if self.optimizer is not None:
                 self.optimizer.refresh_fp32_params()
         else:
             has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
