diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index da50cfd830ec..08fa1eb671bd 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -164,7 +164,6 @@ def checkpoint_correctness_verification(config_dict, tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False, - fp16=True, train_batch=False, base_optimizers=[None, None], empty_tag=False, @@ -245,7 +244,7 @@ def checkpoint_correctness_verification(config_dict, load_module_only=load_module_only) if load_optimizer_states: - compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16) + compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16) if load_lr_scheduler_states: compare_lr_scheduler_states(trained_model, loaded_model) diff --git a/tests/unit/checkpoint/test_latest_checkpoint.py b/tests/unit/checkpoint/test_latest_checkpoint.py index 5abb78785c01..5d795c4dadcf 100644 --- a/tests/unit/checkpoint/test_latest_checkpoint.py +++ b/tests/unit/checkpoint/test_latest_checkpoint.py @@ -38,7 +38,6 @@ def test_existing_latest(self, tmpdir): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=False, empty_tag=True, dtype=torch.float) diff --git a/tests/unit/checkpoint/test_moe_checkpoint.py b/tests/unit/checkpoint/test_moe_checkpoint.py index 0706b7327ce8..36efe2a69002 100644 --- a/tests/unit/checkpoint/test_moe_checkpoint.py +++ b/tests/unit/checkpoint/test_moe_checkpoint.py @@ -33,10 +33,10 @@ def test_checkpoint_moe(self, tmpdir, ep_size): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)]) def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): @@ -77,7 +77,7 @@ def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): tmpdir=tmpdir, load_optimizer_states=load_optim_states, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) diff --git a/tests/unit/checkpoint/test_other_optimizer.py b/tests/unit/checkpoint/test_other_optimizer.py index ebdf303c2ce2..bcff7f5e3072 100644 --- a/tests/unit/checkpoint/test_other_optimizer.py +++ b/tests/unit/checkpoint/test_other_optimizer.py @@ -135,5 +135,4 @@ def test_checkpoint_fp32_optimizer(self, tmpdir): models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - fp16=False, dtype=torch.float32) diff --git a/tests/unit/checkpoint/test_pipeline.py b/tests/unit/checkpoint/test_pipeline.py index 9686780af42d..c6c228ccada7 100644 --- a/tests/unit/checkpoint/test_pipeline.py +++ b/tests/unit/checkpoint/test_pipeline.py @@ -58,7 +58,6 @@ def test_checkpoint_pipe_engine(self, zero_stage, tmpdir): models=models, hidden_dim=models[0].hidden_dim, tmpdir=tmpdir, - fp16=config_dict['fp16']['enabled'], load_optimizer_states=True, load_lr_scheduler_states=True, train_batch=True,