diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 992d7877c179..9a2b943b0992 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2542,7 +2542,7 @@ def all_gather_scalar(self, value, dp_group): return tensor_list def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False): - sd = self.module.state_dict(destination, prefix, keep_vars) + sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # Remove frozen parameter weights from state_dict if specified if exclude_frozen_parameters: