diff --git a/tests/unit/inference/test_inference_config.py b/tests/unit/inference/test_inference_config.py index 375563abf65b..39d62d17372c 100644 --- a/tests/unit/inference/test_inference_config.py +++ b/tests/unit/inference/test_inference_config.py @@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest): world_size = 1 def test_overlap_kwargs(self): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": torch.float32} kwargs = {"replace_with_kernel_inject": True} engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs) @@ -37,7 +37,7 @@ def test_kwargs_and_config(self): assert engine._config.dtype == kwargs["dtype"] def test_json_config(self, tmpdir): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"} config_json = create_config_from_dict(tmpdir, config) engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)