Skip to content

Commit

Permalink
Dtype support check for accelerator in UTs (#6360)
Browse files Browse the repository at this point in the history
Check whether the dtype is supported by the accelerator; if it is not, skip the test.

---------

Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
4 people authored Aug 28, 2024
1 parent 1041c8a commit b5cf30a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
3 changes: 3 additions & 0 deletions tests/unit/inference/test_checkpoint_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from huggingface_hub import snapshot_download
from transformers.utils import is_offline_mode
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator

# Skip this entire test module when the inference op builder is not
# compatible with the current system (e.g. missing CUDA toolchain).
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
Expand Down Expand Up @@ -44,6 +45,8 @@ def model_name(request):

@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    """Parametrized dtype fixture.

    Skips the current parametrization when the active accelerator does not
    support the requested dtype; otherwise returns the torch dtype.
    """
    requested = request.param
    accelerator = get_accelerator()
    if requested not in accelerator.supported_dtypes():
        pytest.skip(f"{requested} not supported by {accelerator.device_name()}.")
    return requested


Expand Down
3 changes: 3 additions & 0 deletions tests/unit/inference/test_model_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# Skip this entire test module when the inference op builder is not
# compatible with the current system.
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("This op had not been implemented on this system.", allow_module_level=True)

# These profiling tests require half precision; skip the module when the
# active accelerator does not support torch.half.
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


@pytest.mark.inference
@pytest.mark.parametrize("use_cuda_events", [True, False])
Expand Down
7 changes: 1 addition & 6 deletions tests/unit/ops/transformer/inference/inference_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@ def get_tolerances():
def get_dtypes():
    """Return the dtypes supported by the active accelerator.

    The result is memoized in the module-level ``DTYPES`` global so the
    accelerator is only queried once per test session.
    """
    global DTYPES
    if DTYPES is not None:
        return DTYPES
    DTYPES = get_accelerator().supported_dtypes()
    return DTYPES


Expand Down

0 comments on commit b5cf30a

Please sign in to comment.