remove "fp16" argument in checkpoint_correctness_verification
delock committed Mar 12, 2024
1 parent 88567b3 commit c94003b
Showing 5 changed files with 5 additions and 9 deletions.
tests/unit/checkpoint/common.py (3 changes: 1 addition & 2 deletions)

@@ -164,7 +164,6 @@ def checkpoint_correctness_verification(config_dict,
                                         tmpdir,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
-                                        fp16=True,
                                         train_batch=False,
                                         base_optimizers=[None, None],
                                         empty_tag=False,
@@ -245,7 +244,7 @@ def checkpoint_correctness_verification(config_dict,
                                 load_module_only=load_module_only)
 
     if load_optimizer_states:
-        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
+        compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16)
 
     if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
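
The net effect in common.py: the helper no longer takes a separate fp16 flag, and whether optimizer states are compared in half precision is instead derived from the existing dtype parameter. A minimal sketch of that derivation (fp16_flag is a hypothetical standalone function written for illustration, not part of the repo):

    import torch

    # Hypothetical distillation of the change above: the boolean that used
    # to arrive as fp16=True/False is now computed from dtype at the call
    # site of compare_optimizer_states.
    def fp16_flag(dtype: torch.dtype) -> bool:
        return dtype == torch.float16

    assert fp16_flag(torch.float16) is True
    assert fp16_flag(torch.float32) is False
    assert fp16_flag(torch.bfloat16) is False

One consequence of this design choice is that callers can no longer claim fp16 behavior while passing a mismatched dtype; the two can no longer disagree.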
tests/unit/checkpoint/test_latest_checkpoint.py (1 change: 0 additions & 1 deletion)

@@ -38,7 +38,6 @@ def test_existing_latest(self, tmpdir):
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
-                                           fp16=False,
                                            empty_tag=True,
                                            dtype=torch.float)
 
tests/unit/checkpoint/test_moe_checkpoint.py (8 changes: 4 additions & 4 deletions)

@@ -33,10 +33,10 @@ def test_checkpoint_moe(self, tmpdir, ep_size):
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
-                                           fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
-                                           seq_dataloader=True)
+                                           seq_dataloader=True,
+                                           dtype=torch.float16)
 
     @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
     def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
@@ -77,7 +77,7 @@ def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
                                            tmpdir=tmpdir,
                                            load_optimizer_states=load_optim_states,
                                            load_lr_scheduler_states=False,
-                                           fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
-                                           seq_dataloader=True)
+                                           seq_dataloader=True,
+                                           dtype=torch.float16)
tests/unit/checkpoint/test_other_optimizer.py (1 change: 0 additions & 1 deletion)

@@ -135,5 +135,4 @@ def test_checkpoint_fp32_optimizer(self, tmpdir):
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
-                                           fp16=False,
                                            dtype=torch.float32)
tests/unit/checkpoint/test_pipeline.py (1 change: 0 additions & 1 deletion)

@@ -58,7 +58,6 @@ def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
                                            models=models,
                                            hidden_dim=models[0].hidden_dim,
                                            tmpdir=tmpdir,
-                                           fp16=config_dict['fp16']['enabled'],
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True,
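
Taken together, the updated call sites select precision through dtype alone. A hedged sketch of a post-commit call (argument names are taken from the diffs above; config_dict, models, hidden_dim, and tmpdir are assumed test fixtures from the surrounding test class, so this is illustrative rather than runnable standalone):

    import torch

    from tests.unit.checkpoint.common import checkpoint_correctness_verification

    # fp16=... is gone; dtype now drives both the training precision and the
    # fp16 check applied inside compare_optimizer_states.
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True,
                                        dtype=torch.float16)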
