diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 965b446163ec..325188f02931 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -540,6 +540,11 @@ def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups + @property + def state(self): + """Forward the wrapped optimizer's states.""" + return self.optimizer.state + def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): assert self.immediate_grad_update self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)