diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index 05854babece5..ec9e9e94aeaf 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -6,6 +6,7 @@ from types import SimpleNamespace import torch +import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape import deepspeed.comm as dist @@ -126,6 +127,8 @@ def test_scattered_init_dist(self): assert dist.is_initialized() def test_scatter_halftype(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") setup_serial_env() with deepspeed.zero.Init():