Patch summary (free text ahead of the first "diff --git" header):
  * deepspeed/runtime/bf16_optimizer.py -- _load_universal_checkpoint now calls
    self.optimizer.step() and self._lazy_init_hp_params_optimizer_state()
    before self._load_hp_checkpoint_state(checkpoint_folder).
  * deepspeed/runtime/engine.py -- load_checkpoint passes
    client_states['iteration'] (previously client_states['iteration'] + 1)
    to update_optimizer_step when load_zero_checkpoint is set.
NOTE(review): the original text had all line breaks and leading indentation
collapsed; the indentation of context lines below is reconstructed from the
Python nesting and must be confirmed against the target files (or apply with
whitespace-insensitive matching). The last hunk ends with one blank context
line.

diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index aaa836bf1c31..406888bd5030 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -445,6 +445,8 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
             self._link_all_hp_params()
 
     def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
+        self.optimizer.step()
+        self._lazy_init_hp_params_optimizer_state()
         self._load_hp_checkpoint_state(checkpoint_folder)
 
     @property
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5c1202ba06ae..664ff1a89c0a 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
 