Skip to content

Commit

Permalink
fix test_inference_config UT error
Browse files Browse the repository at this point in the history
  • Loading branch information
delock committed Jan 3, 2024
1 parent 1596224 commit a72beea
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/unit/inference/test_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest):
world_size = 1

def test_overlap_kwargs(self):
config = {"replace_with_kernel_inject": True}
config = {"replace_with_kernel_inject": True, "dtype": torch.float32}
kwargs = {"replace_with_kernel_inject": True}

engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)
Expand All @@ -37,7 +37,7 @@ def test_kwargs_and_config(self):
assert engine._config.dtype == kwargs["dtype"]

def test_json_config(self, tmpdir):
config = {"replace_with_kernel_inject": True}
config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"}
config_json = create_config_from_dict(tmpdir, config)

engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)
Expand Down

0 comments on commit a72beea

Please sign in to comment.