Skip to content

Commit

Permalink
Smangrul/fix failing ds ci tests (huggingface#27358)
Browse files Browse the repository at this point in the history
* fix failing DeepSpeed CI tests due to `safetensors` being default

* debug

* remove debug statements

* resolve comments

* Update test_deepspeed.py
  • Loading branch information
pacman100 authored Nov 9, 2023
1 parent ced9fd8 commit 7ecd229
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tests/deepspeed/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
slow,
)
from transformers.trainer_utils import get_last_checkpoint, set_seed
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_gpu_available


if is_torch_available():
Expand Down Expand Up @@ -565,8 +565,7 @@ def test_gradient_accumulation(self, stage, dtype):

def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints

file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]

if stage == ZERO2:
ds_file_list = ["mp_rank_00_model_states.pt"]
Expand All @@ -581,7 +580,6 @@ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtyp
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")

# common files
for filename in file_list:
path = os.path.join(checkpoint, filename)
Expand Down

0 comments on commit 7ecd229

Please sign in to comment.