From a72beea53a8317d551fce74e03548c8fd76c8f0a Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 3 Jan 2024 03:14:29 +0000 Subject: [PATCH] fix test_inference_config UT error --- tests/unit/inference/test_inference_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/inference/test_inference_config.py b/tests/unit/inference/test_inference_config.py index 375563abf65b..39d62d17372c 100644 --- a/tests/unit/inference/test_inference_config.py +++ b/tests/unit/inference/test_inference_config.py @@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest): world_size = 1 def test_overlap_kwargs(self): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": torch.float32} kwargs = {"replace_with_kernel_inject": True} engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs) @@ -37,7 +37,7 @@ def test_kwargs_and_config(self): assert engine._config.dtype == kwargs["dtype"] def test_json_config(self, tmpdir): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"} config_json = create_config_from_dict(tmpdir, config) engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)