diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 9c7a84d4841e..550af8fac057 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -365,7 +365,8 @@ def load_state_dict(self,
                         state_dict_list,
                         checkpoint_folder,
                         load_optimizer_states=True,
-                        load_from_fp32_weights=False):
+                        load_from_fp32_weights=False,
+                        load_serial=None):
         if checkpoint_folder:
             self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
         else:
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b4c8ef56c701..8a8193ddd8f5 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3287,7 +3287,7 @@ def _get_zero_param_shapes(self):
         # if we don't use it, we get parameters ordered incorrectly
         if hasattr(self.optimizer, "round_robin_bit16_groups"):
             bit16_groups = self.optimizer.round_robin_bit16_groups
-        elif self.bfloat16_enabled() and not self.zero_optimization():
+        elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"):
             bit16_groups = self.optimizer.bf16_groups
         else:
             bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(
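
For context, here is a minimal standalone sketch (not DeepSpeed source; the classes and the `select_bit16_groups` helper below are invented stand-ins, and the `else` branch is simplified relative to the truncated `zero_optimization_stage(...)` expression in the hunk) of the group-selection behavior the `engine.py` change switches to: instead of inferring the BF16 path from the ZeRO config, it checks whether the optimizer actually exposes `bf16_groups`.

```python
# Illustrative sketch of hasattr-based group selection, assuming mocked optimizers.

class RoundRobinOptimizer:
    """Stand-in for a ZeRO optimizer that reorders groups round-robin."""
    round_robin_bit16_groups = ["rr_group_0", "rr_group_1"]
    bit16_groups = ["group_0", "group_1"]


class Bf16OnlyOptimizer:
    """Stand-in for an optimizer that only exposes bf16_groups."""
    bf16_groups = ["bf16_group_0", "bf16_group_1"]


def select_bit16_groups(optimizer, bfloat16_enabled: bool):
    # Prefer round-robin groups when the optimizer provides them.
    if hasattr(optimizer, "round_robin_bit16_groups"):
        return optimizer.round_robin_bit16_groups
    # Fall back to bf16_groups only if the optimizer actually has them,
    # mirroring the hasattr check introduced by the diff.
    elif bfloat16_enabled and hasattr(optimizer, "bf16_groups"):
        return optimizer.bf16_groups
    # Otherwise use the plain bit16 groups (simplified here).
    return optimizer.bit16_groups


print(select_bit16_groups(RoundRobinOptimizer(), bfloat16_enabled=False))  # round-robin groups
print(select_bit16_groups(Bf16OnlyOptimizer(), bfloat16_enabled=True))     # bf16 groups
```

With the previous condition (`bfloat16_enabled() and not zero_optimization()`), an optimizer without `bf16_groups` could be routed into the BF16 branch and fail with an attribute error; gating on `hasattr(self.optimizer, "bf16_groups")` keys the branch to the attribute that is actually accessed.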