diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 946d748a9dee..3c638802c3bd 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -463,7 +463,7 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): else: tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ else self.mpu.get_tensor_model_parallel_world_size() - #tp_world_size = self.mpu.get_slice_parallel_world_size() + for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bf16_groups[i]: