diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 00346317ebf1..9c0da4c2d406 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2745,10 +2745,12 @@ def load_checkpoint(self, load_module_only=load_module_only, custom_load_fn=custom_load_fn) - load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization() - or self.bfloat16_enabled()) + load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) if load_zero_checkpoint: - success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + if load_optimizer_states and not load_module_only: + success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + else: + success = False if not success: self.optimizer._restore_from_bit16_weights() @@ -2830,7 +2832,7 @@ def _load_checkpoint(self, optim_checkpoint = None if load_module_only: deepspeed_states = ['module'] - if self.optimizer is not None and self.fp16_enabled(): + if self.optimizer is not None: self.optimizer.refresh_fp32_params() else: has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()