From 0254fbfa5dd81880455fd25bdd150e6e3d2d93d7 Mon Sep 17 00:00:00 2001 From: kazuki Date: Mon, 2 Oct 2023 00:19:00 +0900 Subject: [PATCH 1/4] fix: bf16 optimizer if condition --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 22d7c882eb1a..1a01aaf2eaaf 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3279,7 +3279,7 @@ def _get_zero_param_shapes(self): # if we don't use it, we get parameters ordered incorrectly if hasattr(self.optimizer, "round_robin_bit16_groups"): bit16_groups = self.optimizer.round_robin_bit16_groups - elif self.bfloat16_enabled() and not self.zero_optimization(): + elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): bit16_groups = self.optimizer.bf16_groups else: bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( From c03a6cff57fbf588107e0356aed57e4b6f85f7de Mon Sep 17 00:00:00 2001 From: kazuki Date: Mon, 2 Oct 2023 01:24:09 +0900 Subject: [PATCH 2/4] fix: unexpected keyword argument 'load_serial' --- deepspeed/runtime/engine.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1a01aaf2eaaf..629d8955957d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2868,11 +2868,17 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if zero_sd_list is None: return False - self.optimizer.load_state_dict(state_dict_list=zero_sd_list, - load_optimizer_states=load_optimizer_states, - load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder, - load_serial=load_serial) + if self.zero_optimization_stage() == 3: + self.optimizer.load_state_dict(state_dict_list=zero_sd_list, + load_optimizer_states=load_optimizer_states, + load_from_fp32_weights=self.zero_load_from_fp32_weights(), + checkpoint_folder=checkpoint_folder, + load_serial=load_serial) + else: + self.optimizer.load_state_dict(state_dict_list=zero_sd_list, + load_optimizer_states=load_optimizer_states, + load_from_fp32_weights=self.zero_load_from_fp32_weights(), + checkpoint_folder=checkpoint_folder) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') From 4422e131c0497c1a04b72e7156cb7ab51ab9de7b Mon Sep 17 00:00:00 2001 From: kazuki Date: Wed, 4 Oct 2023 01:01:02 +0900 Subject: [PATCH 3/4] fix: add load_serial arg to bf16_optimizer --- deepspeed/runtime/bf16_optimizer.py | 3 ++- deepspeed/runtime/engine.py | 16 +++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 9c7a84d4841e..550af8fac057 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -365,7 +365,8 @@ def load_state_dict(self, state_dict_list, checkpoint_folder, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + load_serial=None): if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 629d8955957d..ce2e08c2b271 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2868,17 +2868,11 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if zero_sd_list is None: return False - if self.zero_optimization_stage() == 3: - self.optimizer.load_state_dict(state_dict_list=zero_sd_list, - load_optimizer_states=load_optimizer_states, - load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder, - load_serial=load_serial) - else: - self.optimizer.load_state_dict(state_dict_list=zero_sd_list, - load_optimizer_states=load_optimizer_states, - load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder) + self.optimizer.load_state_dict(state_dict_list=zero_sd_list, + load_optimizer_states=load_optimizer_states, + load_from_fp32_weights=self.zero_load_from_fp32_weights(), + checkpoint_folder=checkpoint_folder, + load_serial=load_serial) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') From ab6a8c7b2222d5f12886d72c030ecd114e1b8ddf Mon Sep 17 00:00:00 2001 From: kazuki Date: Wed, 4 Oct 2023 01:03:16 +0900 Subject: [PATCH 4/4] style: fix indentation --- deepspeed/runtime/engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ce2e08c2b271..1a01aaf2eaaf 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2869,10 +2869,10 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): return False self.optimizer.load_state_dict(state_dict_list=zero_sd_list, - load_optimizer_states=load_optimizer_states, - load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder, - load_serial=load_serial) + load_optimizer_states=load_optimizer_states, + load_from_fp32_weights=self.zero_load_from_fp32_weights(), + checkpoint_folder=checkpoint_folder, + load_serial=load_serial) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')