From c08e69f21238f15bfe0e3779170fefa2f75d4c7e Mon Sep 17 00:00:00 2001
From: "Ma, Guokai"
Date: Wed, 13 Mar 2024 04:48:29 +0800
Subject: [PATCH] Make op builder detection adapt to accelerator change (#5206)

This is a WIP PR that makes op builder detection adapt to accelerator changes. It is a follow-up to https://github.com/microsoft/DeepSpeed/issues/5173

Currently, DeepSpeed generates `installed_ops` and `compatible_ops` at setup time. If the system changes to a different accelerator by DeepSpeed launch time, these two lists contain incorrect information. This PR intends to solve this problem with more flexible op detection:

* For `installed_ops`, DeepSpeed should disable all installed ops if the accelerator detected at setup time differs from the one detected at launch time.
* For `compatible_ops`, DeepSpeed should refresh the list on each launch so that an accelerator change has no stale impact.

(An illustrative sketch of this runtime gating appears after the setup.py changes below.)

As a first step, the nv-inference workflow is temporarily changed to emulate the scenario where the system is set up with CPU_Accelerator and then launched with CUDA_Accelerator, and CPU_Accelerator is modified so that Intel Extension for PyTorch and oneCCL Bindings for PyTorch are no longer mandatory. Starting from here, installed_ops and compatible_ops can be reconstructed to follow the design above.

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 .github/workflows/cpu-inference.yml | 23 +--
 .github/workflows/cpu-torch-latest.yml | 4 +
 .github/workflows/nv-inference.yml | 5 +-
 .github/workflows/nv-pre-compile-ops.yml | 2 +-
 accelerator/cpu_accelerator.py | 28 +++-
 accelerator/real_accelerator.py | 24 ++-
 deepspeed/env_report.py | 5 +-
 deepspeed/git_version_info.py | 11 +-
 deepspeed/ops/__init__.py | 2 -
 .../runtime/zero/partition_parameters.py | 9 +-
 op_builder/all_ops.py | 1 +
 op_builder/builder.py | 5 +-
 op_builder/xpu/builder.py | 5 +-
 setup.py | 8 +-
 tests/unit/checkpoint/common.py | 11 +-
 .../unit/checkpoint/test_latest_checkpoint.py | 4 +-
 tests/unit/checkpoint/test_lr_scheduler.py | 19 ++-
 tests/unit/checkpoint/test_moe_checkpoint.py | 8 +-
 tests/unit/checkpoint/test_other_optimizer.py | 20 ++-
 tests/unit/checkpoint/test_pipeline.py | 4 +-
 tests/unit/checkpoint/test_zero_optimizer.py | 143 ++++++++---------
 tests/unit/common.py | 10 ++
 tests/unit/compression/test_dequantization.py | 9 +-
 tests/unit/elasticity/test_elastic.py | 4 +-
 tests/unit/launcher/test_user_args.py | 4 +-
 tests/unit/multi_output_model.py | 8 +-
 .../accelerators/test_accelerator_backward.py | 4 -
 .../test_activation_checkpointing.py | 4 +
 .../comm/test_coalesced_collectives.py | 8 +
 .../runtime/compile/test_compile_wrapper.py | 2 +
 .../unit/runtime/compile/test_compile_zero.py | 3 +
 .../unit/runtime/compile/test_load_config.py | 12 ++
 .../half_precision/onebit/test_onebit.py | 33 ++++
 .../unit/runtime/half_precision/test_bf16.py | 5 +
 .../half_precision/test_dynamic_loss_scale.py | 15 ++
 .../unit/runtime/half_precision/test_fp16.py | 42 +++++
 tests/unit/runtime/test_data_efficiency.py | 36 +++--
 tests/unit/runtime/test_ds_config_dict.py | 41 ++++-
 tests/unit/runtime/test_ds_initialize.py | 17 +-
 tests/unit/runtime/test_multi_output_model.py | 21 +--
 tests/unit/runtime/test_mup_optimizers.py | 8 +-
 tests/unit/runtime/test_pld.py | 15 +-
 .../zero/test_ignore_unused_parameters.py | 9 +-
 tests/unit/runtime/zero/test_zero.py | 149 ++++++++++--------
 tests/unit/runtime/zero/test_zero_context.py | 17 +-
 .../runtime/zero/test_zero_context_return.py | 18 ++-
 .../runtime/zero/test_zero_leaf_module.py | 12 +-
.../runtime/zero/test_zero_tensor_fragment.py | 18 ++- tests/unit/simple_model.py | 7 +- 49 files changed, 567 insertions(+), 305 deletions(-) diff --git a/.github/workflows/cpu-inference.yml b/.github/workflows/cpu-inference.yml index 3e09d3cc1e49..38dd9bd3efef 100644 --- a/.github/workflows/cpu-inference.yml +++ b/.github/workflows/cpu-inference.yml @@ -47,42 +47,26 @@ jobs: - name: Detect instruction sets on instance run: | lscpu - pip install cmake - git clone https://github.com/intel/intel-extension-for-pytorch - cd intel-extension-for-pytorch/tests/cpu/isa - cmake . - make - ./cpu_features - name: Install numactl run: | sudo apt-get install -y numactl - - name: Install oneCCL Bindings for PyTorch + - name: Install dependencies run: | pip install torch - python -m pip install intel_extension_for_pytorch - # the curl line is for troubleshooting - curl -L https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ - python -m pip install oneccl_bind_pt --index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ - pip install py-cpuinfo # check installed version pip list |grep \\\ - pip list |grep intel-extension-for-pytorch - pip list |grep oneccl-bind-pt - name: Install oneCCL run: | + pip install cmake git clone https://github.com/oneapi-src/oneCCL cd oneCCL mkdir build cd build cmake .. - make - make install - #source ./_install/env/setvars.sh - # test whether oneCCL is correctly installed - #mpirun -n 2 ./examples/benchmark/benchmark + make -j install - name: Install transformers run: | @@ -103,7 +87,6 @@ jobs: source oneCCL/build/_install/env/setvars.sh export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 # check whether the environment is properly setup - python -c "import torch;import intel_extension_for_pytorch as ipex;import oneccl_bindings_for_pytorch;print('done')" python -c "import deepspeed;from deepspeed.accelerator import get_accelerator;print(get_accelerator().device_name());print(get_accelerator().is_available())" - name: Unit tests diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index ba4906db15c9..5096de931be4 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -27,6 +27,10 @@ jobs: - id: setup-venv uses: ./.github/workflows/setup-venv + - name: Install system packages + run: | + sudo apt-get install -y numactl pdsh + - name: Install pytorch run: | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index a24376c8973d..2b74e7e155df 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -46,7 +46,8 @@ jobs: - name: Install deepspeed run: | - pip install .[dev,1bit,autotuning,inf,triton] + DS_ACCELERATOR=cpu pip install .[dev,1bit,autotuning,inf] + #pip install .[dev,1bit,autotuning,inf,triton] ds_report - name: Python environment @@ -60,3 +61,5 @@ jobs: #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="11.8" pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="11.8" pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="11.8" + # run ds_report again to check updated op list + ds_report diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 6440de1a81ba..18db40380577 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -36,7 +36,7 @@ 
jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . - name: DS Report run: | ds_report diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index b1aba75b4c5e..870d3e91816e 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -4,9 +4,14 @@ # DeepSpeed Team import torch -from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator -import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore -import psutil +from .abstract_accelerator import DeepSpeedAccelerator + +try: + import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore + oneccl_imported_p = True +except ImportError as e: + oneccl_imported_p = False + import os @@ -15,8 +20,17 @@ class CPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cpu' - self._communication_backend_name = 'ccl' - self.max_mem = psutil.Process().memory_info().rss + if oneccl_imported_p: + self._communication_backend_name = 'ccl' + else: + # fallback to gloo if oneccl_binding_for_pytorch is not installed + self._communication_backend_name = 'gloo' + try: + import psutil + mem = psutil.Process().memory_info().rss + self.max_mem = mem + except ImportError as e: + self.max_mem = 0 def is_synchronized_device(self): return True @@ -115,12 +129,14 @@ def empty_cache(self): return def get_rss(self): + import psutil mem = psutil.Process().memory_info().rss if mem > self.max_mem: self.max_mem = mem return mem def reset_rss(self): + import psutil mem = psutil.Process().memory_info().rss self.max_mem = mem return mem @@ -166,9 +182,11 @@ def max_memory_reserved(self, device_index=None): return self.max_mem def total_memory(self, device_index=None): + import psutil return psutil.virtual_memory().total def available_memory(self, device_index=None): + import psutil return psutil.virtual_memory().available # Misc diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 1090a61681d9..037162e867ec 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -73,11 +73,7 @@ def get_accelerator(): f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system." ) elif accelerator_name == "cpu": - try: - import intel_extension_for_pytorch # noqa: F401 # type: ignore - except ImportError as e: - raise ValueError( - f"CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.") + pass elif accelerator_name == "npu": try: import torch_npu # noqa: F401 # type: ignore @@ -154,7 +150,23 @@ def get_accelerator(): except ImportError as e: pass if accelerator_name is None: - accelerator_name = "cuda" + # borrow this log from PR#5084 + try: + import torch + + # Determine if we are on a GPU or x86 CPU with torch. + if torch.cuda.is_available(): #ignore-cuda + accelerator_name = "cuda" + else: + if accel_logger is not None: + accel_logger.warn( + "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it." 
+ ) + accelerator_name = "cpu" + except (RuntimeError, ImportError) as e: + # TODO need a more decent way to detect which accelerator to use, consider using nvidia-smi command for detection + accelerator_name = "cuda" + pass ds_set_method = "auto detect" diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py index 912262453748..85a2f9b2b6c0 100644 --- a/deepspeed/env_report.py +++ b/deepspeed/env_report.py @@ -9,7 +9,7 @@ import subprocess import argparse from .ops.op_builder.all_ops import ALL_OPS -from .git_version_info import installed_ops, torch_info +from .git_version_info import installed_ops, torch_info, accelerator_name from deepspeed.accelerator import get_accelerator GREEN = '\033[92m' @@ -51,7 +51,8 @@ def op_report(verbose=True): for op_name, builder in ALL_OPS.items(): dots = "." * (max_dots - len(op_name)) is_compatible = OKAY if builder.is_compatible(verbose) else no - is_installed = installed if installed_ops.get(op_name, False) else no + is_installed = installed if installed_ops.get(op_name, + False) and accelerator_name == get_accelerator()._name else no dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len)) print(op_name, dots, is_installed, dots2, is_compatible) print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index 635842c760ea..70c536d2f78e 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -18,5 +18,14 @@ from .ops.op_builder.all_ops import ALL_OPS installed_ops = dict.fromkeys(ALL_OPS.keys(), False) - compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) + accelerator_name = "" torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"} + +# compatible_ops list is recreated for each launch +from .ops.op_builder.all_ops import ALL_OPS + +compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + compatible_ops[op_name] = op_compatible + compatible_ops["deepspeed_not_implemented"] = False diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index ba1c9c1fd9f0..7ea5ce5af19e 100755 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -7,8 +7,6 @@ from . import adagrad from . import lamb from . import lion -#from ..git_version_info_installed import installed_ops as __installed_ops__ -#if __installed_ops__['sparse_attn']: from . import sparse_attention from . 
import transformer diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 142259c1b7df..c8099791f882 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -56,7 +56,8 @@ def __init__(self, param: Parameter) -> None: self.__param = param def wait(self) -> None: - get_accelerator().current_stream().synchronize() + if not get_accelerator().is_synchronized_device(): + get_accelerator().current_stream().synchronize() self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -81,7 +82,8 @@ def wait(self) -> None: if self.__complete: return - get_accelerator().current_stream().synchronize() + if not get_accelerator().is_synchronized_device(): + get_accelerator().current_stream().synchronize() for param in self.__params: assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" param.ds_status = ZeroParamStatus.AVAILABLE @@ -363,7 +365,8 @@ def _set_dtype(self, ds_config, dtype): else: self.dtype = torch.float else: - self.dtype = dtype or torch.half + self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported( + ) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32 def patch_init_and_builtins(self): diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py index 9c41f35eaf1b..ff11ca180072 100644 --- a/op_builder/all_ops.py +++ b/op_builder/all_ops.py @@ -30,3 +30,4 @@ __op_builders__.append(builder) ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} +accelerator_name = get_accelerator()._name diff --git a/op_builder/builder.py b/op_builder/builder.py index dd77f967cc60..8dc825c7926d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -464,8 +464,9 @@ def load(self, verbose=True): if self.name in __class__._loaded_ops: return __class__._loaded_ops[self.name] - from deepspeed.git_version_info import installed_ops, torch_info - if installed_ops.get(self.name, False): + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name + from deepspeed.accelerator import get_accelerator + if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name: # Ensure the op we're about to load was compiled with the same # torch/cuda versions we are currently using at runtime. 
self.validate_torch_version(torch_info) diff --git a/op_builder/xpu/builder.py b/op_builder/xpu/builder.py index 459dcce6bfae..81b15f197f43 100644 --- a/op_builder/xpu/builder.py +++ b/op_builder/xpu/builder.py @@ -74,8 +74,9 @@ def fixed_aotflags(self): ] def load(self, verbose=True): - from deepspeed.git_version_info import installed_ops, torch_info # noqa: F401 - if installed_ops.get(self.name, False): + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name # noqa: F401 + from deepspeed.accelerator import get_accelerator + if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name: return importlib.import_module(self.absolute_name()) else: return self.jit_load(verbose) diff --git a/setup.py b/setup.py index 25b741af9440..f1367b850e02 100755 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ 'Please visit https://pytorch.org/ to see how to properly install torch on your system.') from op_builder import get_default_compute_capabilities, OpBuilder -from op_builder.all_ops import ALL_OPS +from op_builder.all_ops import ALL_OPS, accelerator_name from op_builder.builder import installed_cuda_version # Fetch rocm state. @@ -168,12 +168,9 @@ def op_enabled(op_name): return int(get_env_if_set(env_var, BUILD_OP_DEFAULT)) -compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) install_ops = dict.fromkeys(ALL_OPS.keys(), False) for op_name, builder in ALL_OPS.items(): op_compatible = builder.is_compatible() - compatible_ops[op_name] = op_compatible - compatible_ops["deepspeed_not_implemented"] = False # If op is requested but not available, throw an error. if op_enabled(op_name) and not op_compatible: @@ -280,11 +277,10 @@ def create_dir_symlink(src, dest): fd.write(f"git_hash='{git_hash}'\n") fd.write(f"git_branch='{git_branch}'\n") fd.write(f"installed_ops={install_ops}\n") - fd.write(f"compatible_ops={compatible_ops}\n") + fd.write(f"accelerator_name='{accelerator_name}'\n") fd.write(f"torch_info={torch_info}\n") print(f'install_requires={install_requires}') -print(f'compatible_ops={compatible_ops}') print(f'ext_modules={ext_modules}') # Parse README.md to make long_description for PyPI page. 
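The runtime behavior introduced by the builder and git_version_info changes above can be summarized with a small sketch. This is illustrative only: `use_prebuilt_op` and `refresh_compatible_ops` are hypothetical helper names, while `installed_ops`, `accelerator_name`, `ALL_OPS`, `get_accelerator()._name`, and `is_compatible()` are the names this patch actually uses.

```python
# Hedged sketch of the op-detection design in this PR (helper names are
# hypothetical; only the imported names come from the patch).
from deepspeed.accelerator import get_accelerator
from deepspeed.git_version_info import installed_ops, accelerator_name
from deepspeed.ops.op_builder.all_ops import ALL_OPS


def use_prebuilt_op(op_name: str) -> bool:
    # A pre-compiled op is trusted only when the accelerator recorded at
    # setup time matches the accelerator detected at launch time; otherwise
    # builder.load() falls back to JIT compilation.
    return installed_ops.get(op_name, False) and accelerator_name == get_accelerator()._name


def refresh_compatible_ops() -> dict:
    # compatible_ops is recreated on every launch instead of being frozen at
    # setup time, so an accelerator change is reflected immediately.
    return {name: builder.is_compatible() for name, builder in ALL_OPS.items()}
```

With this split, `installed_ops` stays a build-time record (now paired with the accelerator it was built for), while compatibility is always evaluated against the live accelerator, which is what allows a wheel built with CPU_Accelerator to still JIT-load ops when launched with CUDA_Accelerator.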
diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index 7442e51bad5d..08fa1eb671bd 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -14,6 +14,7 @@ from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from unit.common import preferred_dtype from unit.simple_model import * from unittest.mock import MagicMock, patch @@ -163,13 +164,15 @@ def checkpoint_correctness_verification(config_dict, tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False, - fp16=True, train_batch=False, base_optimizers=[None, None], empty_tag=False, seq_dataloader=False, - load_module_only=False): - dtype = torch.half if fp16 else torch.float32 + load_module_only=False, + dtype=None): + if dtype == None: + dtype = preferred_dtype() + ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0]) if seq_dataloader: @@ -241,7 +244,7 @@ def checkpoint_correctness_verification(config_dict, load_module_only=load_module_only) if load_optimizer_states: - compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16) + compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16) if load_lr_scheduler_states: compare_lr_scheduler_states(trained_model, loaded_model) diff --git a/tests/unit/checkpoint/test_latest_checkpoint.py b/tests/unit/checkpoint/test_latest_checkpoint.py index 41ce2278680f..5d795c4dadcf 100644 --- a/tests/unit/checkpoint/test_latest_checkpoint.py +++ b/tests/unit/checkpoint/test_latest_checkpoint.py @@ -38,8 +38,8 @@ def test_existing_latest(self, tmpdir): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=False, - empty_tag=True) + empty_tag=True, + dtype=torch.float) def test_missing_latest(self, tmpdir): config_dict = { diff --git a/tests/unit/checkpoint/test_lr_scheduler.py b/tests/unit/checkpoint/test_lr_scheduler.py index c4c6773cd474..4891b4f6fa9b 100644 --- a/tests/unit/checkpoint/test_lr_scheduler.py +++ b/tests/unit/checkpoint/test_lr_scheduler.py @@ -5,6 +5,7 @@ import deepspeed from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest from unit.simple_model import * @@ -22,6 +23,8 @@ class TestLRSchedulerCheckpoint(DistributedTest): def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") + if get_accelerator().device_name() == 'cpu': + pytest.skip("CPU accelerator does not support this test.") config_dict = { "train_batch_size": 2, @@ -35,9 +38,6 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload @@ -51,6 +51,10 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): } } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: @@ -71,6 +75,8 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not 
deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") + if get_accelerator().device_name() == 'cpu': + pytest.skip("CPU accelerator does not support this test.") config_dict = { "train_batch_size": 2, @@ -81,9 +87,6 @@ def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): "lr": 1e-5 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload @@ -97,6 +100,10 @@ def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): } }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: diff --git a/tests/unit/checkpoint/test_moe_checkpoint.py b/tests/unit/checkpoint/test_moe_checkpoint.py index 0706b7327ce8..36efe2a69002 100644 --- a/tests/unit/checkpoint/test_moe_checkpoint.py +++ b/tests/unit/checkpoint/test_moe_checkpoint.py @@ -33,10 +33,10 @@ def test_checkpoint_moe(self, tmpdir, ep_size): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)]) def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): @@ -77,7 +77,7 @@ def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): tmpdir=tmpdir, load_optimizer_states=load_optim_states, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) diff --git a/tests/unit/checkpoint/test_other_optimizer.py b/tests/unit/checkpoint/test_other_optimizer.py index 9cb8c4286880..bcff7f5e3072 100644 --- a/tests/unit/checkpoint/test_other_optimizer.py +++ b/tests/unit/checkpoint/test_other_optimizer.py @@ -19,6 +19,8 @@ class TestOtherOptimizerCheckpoint(DistributedTest): @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test_checkpoint_unfused_optimizer(self, tmpdir): + #if not get_accelerator().is_fp16_supported(): + # pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -29,9 +31,6 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True - }, "scheduler": { "type": "OneCycle", "params": { @@ -49,6 +48,10 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): } } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["bf16"] = {"enabled": True} args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -69,6 +72,8 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): load_optimizer_states=False) def test_checkpoint_fused_optimizer(self, tmpdir): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -81,10 +86,11 @@ def test_checkpoint_fused_optimizer(self, tmpdir): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif 
get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -129,4 +135,4 @@ def test_checkpoint_fp32_optimizer(self, tmpdir): models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - fp16=False) + dtype=torch.float32) diff --git a/tests/unit/checkpoint/test_pipeline.py b/tests/unit/checkpoint/test_pipeline.py index 99f1ba2ec433..c6c228ccada7 100644 --- a/tests/unit/checkpoint/test_pipeline.py +++ b/tests/unit/checkpoint/test_pipeline.py @@ -58,10 +58,10 @@ def test_checkpoint_pipe_engine(self, zero_stage, tmpdir): models=models, hidden_dim=models[0].hidden_dim, tmpdir=tmpdir, - fp16=config_dict['fp16']['enabled'], load_optimizer_states=True, load_lr_scheduler_states=True, - train_batch=True) + train_batch=True, + dtype=torch.float16 if zero_stage > 0 else torch.float32) @pytest.mark.parametrize( "base_topo,test_topo", diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index 0b9efb3ec462..aebad4227358 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -28,15 +28,15 @@ def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, "pipeline_loading_checkpoint": True, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 with deepspeed.zero.Init(): @@ -64,16 +64,16 @@ def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_op "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "wall_clock_breakdown": True, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: @@ -104,14 +104,15 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, ada "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: @@ -134,11 +135,11 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage): "stage": zero_stage }, "zero_allow_untested_optimizer": True, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)] optimizers = [HybridStateOptimizer(model.parameters()) for model in models] @@ -152,19 +153,21 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage): @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3]) def test_load_module_only(self, tmpdir, zero_stage): + if zero_stage == 0 and get_accelerator().device_name() == "cpu": + pytest.skip("CPU Accelerator does not support this test") config_dict = { 
"train_batch_size": 2, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: @@ -185,15 +188,15 @@ def run(self, class_tmpdir, elastic_save, load_optim): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_save } } + if get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -221,15 +224,15 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_save } } + if get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} hidden_dim = 10 # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to @@ -274,15 +277,15 @@ def test_elastic_checkpoint_change_dp(self, ws4_model_checkpoint, class_tmpdir, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_load } } + if get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -305,14 +308,14 @@ def test_immediate_save_load(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -325,30 +328,27 @@ def test_immediate_save_load(self, tmpdir, zero_stage): @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3]) def test_load_immediate_save(self, tmpdir, zero_stage): + if zero_stage == 0 and get_accelerator().device_name() == "cpu": + pytest.skip("CPU Accelerator does not support this test") config_dict = { "train_batch_size": 4, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) # 1. 
pretrain a model and save it - dtype = torch.half ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) - data_loader = random_dataloader(model=ds_model, - total_samples=1, - hidden_dim=hidden_dim, - device=ds_model.device, - dtype=dtype) + data_loader = random_dataloader(model=ds_model, total_samples=1, hidden_dim=hidden_dim, device=ds_model.device) for _, batch in enumerate(data_loader): loss = ds_model(batch[0], batch[1]) ds_model.backward(loss) @@ -371,10 +371,6 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, "stage3_gather_fp16_weights_on_model_save": True, @@ -383,6 +379,10 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): "train_micro_batch_size_per_gpu": 1, "train_batch_size": 4, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -391,11 +391,7 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): # So we config grad_accum=2 and step only once and save_16bit_model ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) - data_loader = random_dataloader(model=ds_model, - total_samples=2, - hidden_dim=hidden_dim, - device=ds_model.device, - dtype=torch.half) + data_loader = random_dataloader(model=ds_model, total_samples=2, hidden_dim=hidden_dim, device=ds_model.device) batch = next(iter(data_loader)) loss = ds_model(batch[0], batch[1]) @@ -429,15 +425,15 @@ def test_load_optimizer_state(self, tmpdir, zero_stage): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "wall_clock_breakdown": True, "zero_optimization": { "stage": zero_stage } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 with deepspeed.zero.Init(enabled=zero_stage == 3): @@ -460,13 +456,14 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 with deepspeed.zero.Init(enabled=zero_stage == 3): @@ -481,14 +478,14 @@ def test_load_module_only(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 with deepspeed.zero.Init(enabled=zero_stage == 3): @@ -504,14 +501,14 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif 
get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleFrozenModel(hidden_dim, empty_grad=False) @@ -552,14 +549,14 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleFrozenModel(hidden_dim, empty_grad=False) diff --git a/tests/unit/common.py b/tests/unit/common.py index 76bebf6b725a..c002aa372c8c 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -441,3 +441,13 @@ def _get_current_test_func(self, request): def get_test_path(filename): curr_path = Path(__file__).parent return str(curr_path.joinpath(filename)) + + +# fp16 > bf16 > fp32 +def preferred_dtype(): + if get_accelerator().is_fp16_supported(): + return torch.float16 + elif get_accelerator().is_bf16_supported(): + return torch.bfloat16 + else: + return torch.float32 diff --git a/tests/unit/compression/test_dequantization.py b/tests/unit/compression/test_dequantization.py index 692f4cef97d7..8446904754b3 100644 --- a/tests/unit/compression/test_dequantization.py +++ b/tests/unit/compression/test_dequantization.py @@ -7,8 +7,9 @@ import os import torch +import pytest from unit.common import DistributedTest -from deepspeed.ops.op_builder import InferenceBuilder +import deepspeed from deepspeed.accelerator import get_accelerator @@ -18,7 +19,11 @@ def init(self): local_rank = int(os.getenv("LOCAL_RANK", "0")) self.device = torch.device(get_accelerator().device_name(local_rank)) - self.dequantize_func = InferenceBuilder().load().dequantize_fp16 + from deepspeed.ops.op_builder import InferenceBuilder + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("InferenceBuilder is not implemented") + else: + self.dequantize_func = InferenceBuilder().load().dequantize_fp16 def run_dequantize_test(self, M, N, num_groups): weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device) diff --git a/tests/unit/elasticity/test_elastic.py b/tests/unit/elasticity/test_elastic.py index a49ec595a420..63633a51914b 100644 --- a/tests/unit/elasticity/test_elastic.py +++ b/tests/unit/elasticity/test_elastic.py @@ -9,7 +9,7 @@ from deepspeed.git_version_info import version as ds_version import os from unit.simple_model import SimpleModel -from deepspeed.ops.op_builder import FusedAdamBuilder +from deepspeed.ops.op_builder import FusedAdamBuilder, FusedLambBuilder if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]: pytest.skip("This op had not been implemented on this system.", allow_module_level=True) @@ -183,6 +183,8 @@ class TestNonElasticBatchParamsWithOverride(DistributedTest): world_size = 2 def test(self): + if not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) config_dict = { "train_batch_size": 2, "steps_per_print": 1, diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py index 99afd0f2cfa7..b86be4dfe74c 100644 --- a/tests/unit/launcher/test_user_args.py +++ b/tests/unit/launcher/test_user_args.py @@ -43,7 +43,9 @@ def cmd(user_script_fp, prompt, multi_node): '''I'm going to tell them "DeepSpeed is the 
best"''' ]) @pytest.mark.parametrize("multi_node", [True, False]) -def test_user_args(cmd): +def test_user_args(cmd, multi_node): + if multi_node and get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = p.communicate() assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" diff --git a/tests/unit/multi_output_model.py b/tests/unit/multi_output_model.py index e84215fb4e95..d7a5f9a46b97 100644 --- a/tests/unit/multi_output_model.py +++ b/tests/unit/multi_output_model.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from .common import preferred_dtype class MultiOutputModel(torch.nn.Module): @@ -28,8 +29,11 @@ def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, ta batch_size = model.train_micro_batch_size_per_gpu() train_data = [ - torch.full(size=(total_samples, hidden_dim), fill_value=x, device=device, dtype=torch.half, requires_grad=True) - for x in inputs + torch.full(size=(total_samples, hidden_dim), + fill_value=x, + device=device, + dtype=preferred_dtype(), + requires_grad=True) for x in inputs ] train_label = [torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets] diff --git a/tests/unit/ops/accelerators/test_accelerator_backward.py b/tests/unit/ops/accelerators/test_accelerator_backward.py index 43f7b471e2ae..48e5fbbe7475 100644 --- a/tests/unit/ops/accelerators/test_accelerator_backward.py +++ b/tests/unit/ops/accelerators/test_accelerator_backward.py @@ -16,10 +16,6 @@ from unit.modelingpreln import BertEncoder as BertEncoderPreln from unit.common import DistributedTest, is_rocm_pytorch -#if not deepspeed.ops.__installed_ops__['transformer']: -#pytest.skip( -# "transformer kernels are temporarily disabled because of unexplained failures", -# allow_module_level=True) if torch.half not in get_accelerator().supported_dtypes(): pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index 0232457a4f9c..22a61003b31e 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -62,6 +62,8 @@ def _match_outputs(ref, tgt): def _test_activation_checkpoint(module, *inputs): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # Move to device module.to(get_accelerator().device_name()) @@ -82,6 +84,8 @@ def _test_activation_checkpoint(module, *inputs): def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # Move to device module.to(get_accelerator().device_name()) diff --git a/tests/unit/runtime/comm/test_coalesced_collectives.py b/tests/unit/runtime/comm/test_coalesced_collectives.py index d9ac79619bd3..17b2ffbb9d29 100644 --- a/tests/unit/runtime/comm/test_coalesced_collectives.py +++ b/tests/unit/runtime/comm/test_coalesced_collectives.py @@ -7,9 +7,11 @@ """ import torch +import deepspeed import deepspeed.comm as dist from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, 
all_to_all_quant_reduce from deepspeed.accelerator import get_accelerator +import pytest from unit.common import DistributedTest @@ -68,6 +70,9 @@ class TestAllToAllQuantReduceFallback(DistributedTest): def test_1d_tensor(self): # case 1: 1D tensor input = torch.zeros((10, ), dtype=torch.half, device=get_accelerator().current_device_name()) + from deepspeed.ops.op_builder import QuantizerBuilder + if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("QuantizerBuilder is not implemented") output = all_to_all_quant_reduce([input], {})[0] if dist.get_rank() == 0: @@ -80,6 +85,9 @@ def test_1d_tensor(self): def test_non_divisible(self): # case 2: tensor size not divisible by global_world_size input = torch.zeros((7, 7), dtype=torch.half, device=get_accelerator().current_device_name()) + from deepspeed.ops.op_builder import QuantizerBuilder + if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("QuantizerBuilder is not implemented") output = all_to_all_quant_reduce([input], {})[0] if dist.get_rank() == 0: diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index 98a7c28c6a28..477b2fe2cc1b 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -72,6 +72,8 @@ def _run_model(self, engine): @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_custom_function(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") test_value = 10 engine = self._init_engine(base_config, test_value) diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 910f32db1c96..b3ab91dc4b4c 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -8,6 +8,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.utils import required_torch_version +from deepspeed.accelerator import get_accelerator from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest @@ -29,6 +30,8 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): pytest.skip( " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") if offload_device == OffloadDeviceEnum.nvme: if zero_stage != 3: diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 5f1c01b86852..2c0511c31480 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -74,12 +74,16 @@ def _run_model(self, engine): @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_compile(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") engine = self._init_engine(base_config) self._run_model(engine) assert engine.is_compiled @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_custom_backend(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") global 
custom_backend_called custom_backend_called = False @@ -89,12 +93,16 @@ def test_custom_backend(self, base_config): assert custom_backend_called def test_compile_disabled(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") base_config["compile"]["enabled"] = False engine = self._init_engine(base_config) self._run_model(engine) @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_compile_kwargs(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") base_config["compile"]["kwargs"] = {"mode": "default"} engine = self._init_engine(base_config) self._run_model(engine) @@ -102,6 +110,8 @@ def test_compile_kwargs(self, base_config): @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_set_compile_kwargs(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") engine = self._init_engine(base_config) engine.set_torch_compile_kwargs({"mode": "default"}) self._run_model(engine) @@ -109,6 +119,8 @@ def test_set_compile_kwargs(self, base_config): @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported") def test_set_compiler_fn(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") global custom_compler_fn_called custom_compler_fn_called = False diff --git a/tests/unit/runtime/half_precision/onebit/test_onebit.py b/tests/unit/runtime/half_precision/onebit/test_onebit.py index ba795a853be0..71b49b7723b6 100644 --- a/tests/unit/runtime/half_precision/onebit/test_onebit.py +++ b/tests/unit/runtime/half_precision/onebit/test_onebit.py @@ -39,6 +39,9 @@ class TestOneBitAdamBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -80,6 +83,8 @@ class TestOneBitAdamExpAvgMask(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -144,6 +149,8 @@ class TestOneBitAdamCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -293,6 +300,8 @@ def test(self, tmpdir): assert optimizer_3.optimizer.adam_freeze_key is False def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -343,6 +352,8 @@ class TestOneBitAdamFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 4, "grandient_accumulation_steps": 1, @@ -388,6 +399,8 @@ class TestZeroOneAdamBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -432,6 +445,8 @@ class TestZeroOneAdamExpAvgMask(DistributedTest): world_size = 2 def 
test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -499,6 +514,8 @@ class TestZeroOneAdamCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -647,6 +664,8 @@ def test(self, tmpdir): assert "server_error" not in v, f"Incorrect server error" def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -700,6 +719,8 @@ class TestZeroOneAdamFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 4, "grandient_accumulation_steps": 1, @@ -748,6 +769,8 @@ class TestOneBitLambBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -795,6 +818,8 @@ class TestOneBitLampExpAvgMask(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -864,6 +889,8 @@ class TestOneBitLambCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -1030,6 +1057,8 @@ def test(self, tmpdir): assert optimizer_3.optimizer.lamb_freeze_key is False def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -1086,6 +1115,8 @@ class TestOneBitLambFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 4, "grandient_accumulation_steps": 1, @@ -1131,6 +1162,8 @@ class TestCompressedAllReduceBasic(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") from deepspeed.runtime.comm.nccl import NcclBackend size = dist.get_world_size() diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/runtime/half_precision/test_bf16.py index 3f551fb0fd4a..d42a4b62cd10 100644 --- a/tests/unit/runtime/half_precision/test_bf16.py +++ b/tests/unit/runtime/half_precision/test_bf16.py @@ -12,6 +12,7 @@ from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader from unit.util import bf16_required_version_check from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator class TestAdamBF16ZeroOneCycleCompatibility(DistributedTest): @@ -299,6 +300,10 @@ def test(self, comp_type, comm_type): " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) + if comp_type == torch.float16 or comm_type == torch.float16: + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"} config_dict = { diff --git 
a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py index 2a58fd6b4a57..f350e08e68a7 100644 --- a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py +++ b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py @@ -5,6 +5,8 @@ import torch import deepspeed +from deepspeed.accelerator import get_accelerator +import pytest import numpy as np from unit.common import DistributedTest from unit.simple_model import SimpleModel @@ -22,6 +24,9 @@ class TestFused(DistributedTest): world_size = 1 def test_no_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -57,6 +62,8 @@ def test_no_overflow(self): expected_loss_scale *= 2 def test_all_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -90,6 +97,8 @@ def test_all_overflow(self): assert optim.cur_iter == (i + 1) def test_some_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -147,6 +156,8 @@ class TestUnfused(DistributedTest): world_size = 1 def test_no_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -181,6 +192,8 @@ def test_no_overflow(self): expected_loss_scale *= 2 def test_all_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -217,6 +230,8 @@ def test_all_overflow(self): assert optim.cur_iter == (i + 1) def test_some_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index 3d5e18b46502..e54fe352bf5b 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -26,6 +26,8 @@ class TestLambFP32GradClip(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -56,6 +58,8 @@ class TestLambFP16(DistributedTest): world_size = 2 def test__basic(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -81,6 +85,8 @@ def test__basic(self): model.step() def test_empty_grad(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -143,6 +149,8 @@ class TestAdamwFP16Basic(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}} hidden_dim = 10 @@ -160,6 +168,8 @@ class TestFP16OptimizerForMoE(DistributedTest): world_size = 2 def test_unfused_gradnorm(self, monkeypatch): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if not required_torch_version(min_version=1.8): 
pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") @@ -188,6 +198,8 @@ def mock_unscale_and_clip_grads(total_norm, apply_scale=True): engine.step() def test_fused_gradnorm(self, monkeypatch): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") @@ -218,6 +230,8 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True) @pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)]) def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") @@ -262,6 +276,8 @@ class TestAdamwFP16EmptyGrad(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}} hidden_dim = 10 @@ -281,6 +297,8 @@ class TestAdamFP16ZeroOneCycleCompatibility(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -332,6 +350,8 @@ class TestZeroStaticScale(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload, hidden_dim=4): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -375,6 +395,8 @@ class TestZeroAllowUntestedOptimizer(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -408,6 +430,8 @@ class TestZeroEmptyPartition(DistributedTest): world_size = 3 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -454,6 +478,8 @@ class TestAmp(DistributedTest): world_size = 2 def test_adam_basic(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = {"train_batch_size": 2, "steps_per_print": 1, "amp": {"enabled": True}} hidden_dim = 10 @@ -467,6 +493,8 @@ def test_adam_basic(self): model.step() def test_lamb_basic(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -492,6 +520,8 @@ def test_lamb_basic(self): model.step() def test_adam_O2(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -518,6 +548,8 @@ def test_adam_O2(self): model.step() def test_adam_O2_empty_grad(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -550,6 
+582,8 @@ class TestZeroSupportedClientOptimizer(DistributedTest): world_size = 1 def test(self, zero_stage, optimizer_constructor): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -571,6 +605,8 @@ class TestZero2ReduceScatterOff(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -610,6 +646,8 @@ class TestFP16AdamTypes(DistributedTest): world_size = 1 def test(self, adam_type, torch_impl): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -642,6 +680,8 @@ class TestZero3LazyScatter(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -677,6 +717,8 @@ class TestZeroEmptyGrad(DistributedTest): world_size = 1 def test(self, stage): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, diff --git a/tests/unit/runtime/test_data_efficiency.py b/tests/unit/runtime/test_data_efficiency.py index b9bd9c3aa56e..87fb49aad830 100644 --- a/tests/unit/runtime/test_data_efficiency.py +++ b/tests/unit/runtime/test_data_efficiency.py @@ -7,6 +7,7 @@ import os import deepspeed from deepspeed.accelerator import get_accelerator +import pytest from unit.common import DistributedTest from unit.simple_model import Curriculum_SimpleModel, SimpleModel, random_dataloader, random_dataset @@ -53,6 +54,8 @@ class TestDataEfficiency(DistributedTest): world_size = 2 def test_curriculum_learning(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -64,11 +67,6 @@ def test_curriculum_learning(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "data_efficiency": { "enabled": True, "seed": 1234, @@ -98,6 +96,10 @@ def test_curriculum_learning(self): } } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} def data_post_process(data, data_sampler_state_dict): assert 'dummy_metric' in data_sampler_state_dict['current_difficulties'] @@ -105,7 +107,7 @@ def data_post_process(data, data_sampler_state_dict): hidden_dim = 10 model = SimpleModel(hidden_dim) - dataset = random_dataset(20, hidden_dim, torch.device('cpu'), dtype=torch.half) + dataset = random_dataset(20, hidden_dim, torch.device('cpu')) model, _, data_loader, _ = deepspeed.initialize(config=config_dict, model=model, training_data=dataset, @@ -128,6 +130,8 @@ class TestLegacyCurriculumScheduler(DistributedTest): world_size = 2 def test_fixed_discrete(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -139,11 +143,6 @@ def test_fixed_discrete(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "curriculum_learning": { "enabled": 
True, "curriculum_type": "seqlen", @@ -156,6 +155,10 @@ def test_fixed_discrete(self): } } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4} @@ -172,6 +175,8 @@ def test_fixed_discrete(self): assert seqlen == true_seqlen, f"Incorrect curriculum schedule" def test_fixed_linear(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -183,11 +188,6 @@ def test_fixed_linear(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "curriculum_learning": { "enabled": True, "curriculum_type": "seqlen", @@ -200,6 +200,10 @@ def test_fixed_linear(self): } } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10} diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py index 880282bb7e57..c11c63d04867 100644 --- a/tests/unit/runtime/test_ds_config_dict.py +++ b/tests/unit/runtime/test_ds_config_dict.py @@ -47,9 +47,6 @@ def base_config(): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } return config_dict @@ -163,11 +160,19 @@ class TestConfigLoad(DistributedTest): world_size = 1 def test_dict(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=base_config, model=model, model_parameters=model.parameters()) def test_json(self, base_config, tmpdir): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} config_path = os.path.join(tmpdir, "config.json") with open(config_path, 'w') as fp: json.dump(base_config, fp) @@ -176,6 +181,10 @@ def test_json(self, base_config, tmpdir): model, _, _, _ = deepspeed.initialize(config=config_path, model=model, model_parameters=model.parameters()) def test_hjson(self, base_config, tmpdir): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} config_path = os.path.join(tmpdir, "config.json") with open(config_path, 'w') as fp: hjson.dump(base_config, fp) @@ -188,6 +197,10 @@ class TestDeprecatedDeepScaleConfig(DistributedTest): world_size = 1 def test(self, base_config, tmpdir): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} config_path = create_config_from_dict(tmpdir, base_config) parser = argparse.ArgumentParser() args = parser.parse_args(args='') @@ -209,6 +222,10 @@ class TestDistInit(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif 
get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -227,6 +244,12 @@ class TestInitNoOptimizer(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + if get_accelerator().device_name() == "cpu": + pytest.skip("This test timeout with CPU accelerator") del base_config["optimizer"] hidden_dim = 10 @@ -246,6 +269,10 @@ class TestArgs(DistributedTest): world_size = 1 def test_none_args(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) model, _, _, _ = deepspeed.initialize(args=None, model=model, config=base_config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device) @@ -253,6 +280,10 @@ def test_none_args(self, base_config): loss = model(batch[0], batch[1]) def test_no_args(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) model, _, _, _ = deepspeed.initialize(model=model, config=base_config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device) @@ -264,6 +295,10 @@ class TestNoModel(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) with pytest.raises(AssertionError): model, _, _, _ = deepspeed.initialize(model=None, config=base_config) diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index 8ec9f05a0a17..1f3cc991eba7 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -18,6 +18,7 @@ from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR from deepspeed.runtime.config import ADAM_OPTIMIZER from deepspeed.runtime.utils import see_memory_usage, required_torch_version +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('zero_stage', [0, 3]) @@ -30,9 +31,6 @@ def test(self, zero_stage): ds_config = { 'train_batch_size': self.world_size, - 'fp16': { - 'enabled': True - }, 'zero_optimization': { "stage": zero_stage, "offload_param": { @@ -40,6 +38,10 @@ def test(self, zero_stage): } } } + if get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} # 20B test #hidden_dim = 16 * 1024 hidden_dim = 4 @@ -49,11 +51,7 @@ def test(self, zero_stage): see_memory_usage('pre-init', force=True) model, _, _, _ = deepspeed.initialize(model=model, config=ds_config) see_memory_usage('post-init', force=True) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.half) + data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) for batch in data_loader: model(batch[0], batch[1]) see_memory_usage('post-fwds', force=True) @@ -120,6 +118,9 @@ class 
TestOptimizerImplementation(DistributedTest): reuse_dist_env = True def test(self, optimizer_extension, model_dtype, grad_accum_dtype): + if not get_accelerator().is_fp16_supported(): + if model_dtype == 'fp16' or grad_accum_dtype == 'fp16': + pytest.skip("fp16 is not supported") if optimizer_extension == 'zero1': zero_stage = 1 elif optimizer_extension == 'zero2': diff --git a/tests/unit/runtime/test_multi_output_model.py b/tests/unit/runtime/test_multi_output_model.py index d9aba419b158..cda0d4f054d3 100644 --- a/tests/unit/runtime/test_multi_output_model.py +++ b/tests/unit/runtime/test_multi_output_model.py @@ -5,8 +5,9 @@ import torch import deepspeed +from deepspeed.accelerator import get_accelerator from pytest import approx -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.multi_output_model import MultiOutputModel, multi_output_dataloader @@ -28,10 +29,11 @@ def test(self, tmpdir): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 weight_value = 0.1 @@ -53,7 +55,7 @@ def test(self, tmpdir): inputs, targets = batch[:midpoint], batch[midpoint:] loss_tuple = model(inputs, targets) - expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device) + expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device) for loss in loss_tuple: assert loss.shape == torch.Size([]) assert loss.item() == approx(expected_loss.item()) @@ -84,10 +86,11 @@ def test(self, tmpdir): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 weight_value = 0.1 @@ -111,7 +114,7 @@ def test(self, tmpdir): loss_tuple = model(inputs, targets) assert len(loss_tuple) == 3 - expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device) + expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device) for loss in loss_tuple: assert loss.shape == torch.Size([]) diff --git a/tests/unit/runtime/test_mup_optimizers.py b/tests/unit/runtime/test_mup_optimizers.py index ebecf73d416f..7666fa9d1c1f 100644 --- a/tests/unit/runtime/test_mup_optimizers.py +++ b/tests/unit/runtime/test_mup_optimizers.py @@ -10,6 +10,7 @@ from unit.common import DistributedTest from unit.simple_model import SimpleModel, random_dataloader from mup.shape import set_base_shapes +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize("optimizer, expected_opt_class", [("MuAdam", torch.optim.Adam), @@ -31,14 +32,15 @@ def test(self, optimizer, expected_opt_class, zero_offload): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 2, "cpu_offload": zero_offload } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) set_base_shapes(model, None) diff --git a/tests/unit/runtime/test_pld.py b/tests/unit/runtime/test_pld.py index 1f602db73b2f..f6da992d5e11 100644 --- a/tests/unit/runtime/test_pld.py +++ b/tests/unit/runtime/test_pld.py @@ -10,6 +10,7 @@ from unit.common import DistributedTest from 
unit.simple_model import SimpleModel, PLD_SimpleModel, random_dataloader +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) @@ -39,15 +40,16 @@ def test_pld_model(self, theta): "lr": 0.0001 } }, - "fp16": { - "enabled": True - }, "progressive_layer_drop": { "enabled": True, "theta": theta, "gamma": gamma } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = PLD_SimpleModel(hidden_dim, empty_grad=False) @@ -80,15 +82,16 @@ def test_non_pld_model(self): "lr": 0.0001 } }, - "fp16": { - "enabled": True - }, "progressive_layer_drop": { "enabled": True, "theta": theta, "gamma": gamma } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim, empty_grad=False) diff --git a/tests/unit/runtime/zero/test_ignore_unused_parameters.py b/tests/unit/runtime/zero/test_ignore_unused_parameters.py index aade488fde42..b1d341486e55 100644 --- a/tests/unit/runtime/zero/test_ignore_unused_parameters.py +++ b/tests/unit/runtime/zero/test_ignore_unused_parameters.py @@ -9,6 +9,7 @@ from deepspeed.ops.op_builder import CPUAdamBuilder import deepspeed +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('ignore_unused_parameters', [False, True]) @@ -36,11 +37,11 @@ def test(self, ignore_unused_parameters): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + else: + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 model = UnusedParametersModel(hidden_dim=hidden_dim) diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 5a8af95bb0f8..7262a1b2c998 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -16,7 +16,7 @@ from torch.nn.parameter import Parameter from torch.nn.utils import skip_init -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel, random_dataloader import deepspeed @@ -71,11 +71,11 @@ def test(self, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -91,6 +91,8 @@ class TestZero3RepeatForwardLoop(DistributedTest): world_size = 1 def test(self, mics_enabled, zero_stage=3): + if mics_enabled and get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # force all params to be partitioned by forcing threshold=0 mics_shard_size = -1 if mics_enabled: @@ -111,11 +113,11 @@ def test(self, mics_enabled, zero_stage=3): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 class AlbertLikeModel(torch.nn.Module): @@ -166,11 +168,11 @@ 
def test_1_param_group(self, tmpdir, zero_stage, freeze_params): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} class MyModel(torch.nn.Module): @@ -260,11 +262,11 @@ def test_2_param_groups(self, tmpdir, zero_stage, freeze_params): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} class MyModel(torch.nn.Module): @@ -366,11 +368,11 @@ def test(self, allgather_bucket_size, zero_stage=2): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -401,11 +403,11 @@ def test(self, zero_stage=2): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -625,6 +627,8 @@ def test_param_persistence_threshold(self, param_persistence_threshold): @pytest.mark.parametrize("fp16_enabled", [True, False]) def test_fp16_enabled(self, fp16_enabled): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") self._test(fp16_enabled=fp16_enabled) @pytest.mark.parametrize("contiguous_gradients", [True, False]) @@ -690,11 +694,11 @@ def _test( "lr": 1.0 } }, - "fp16": { - "enabled": fp16_enabled, - "loss_scale": 1.0, - }, } + if get_accelerator().is_fp16_supported(): + cfg["fp16"] = {"enabled": True, "loss_scale": 1.0} + elif get_accelerator().is_bf16_supported(): + cfg["bf16"] = {"enabled": True} if offload_optimizer: cfg["zero_optimization"]["offload_optimizer"] = { @@ -859,11 +863,11 @@ def forward(self, x: Tensor) -> Tensor: "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1.0, - }, } + if get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True, "loss_scale": 1.0} + elif get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} with deepspeed.zero.Init(mem_efficient_linear=False, enabled=init_context_manager): model = LargeParamModel() ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) @@ -938,24 +942,24 @@ def forward(self, x: Tensor) -> Tensor: "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1.0, - }, } + if get_accelerator().is_fp16_supported(): + ds_cfg["fp16"] = {"enabled": True, "loss_scale": 1.0} + elif get_accelerator().is_bf16_supported(): + ds_cfg["bf16"] = {"enabled": True} with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=init_context_manager): model = ManyParamModel() ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) + dtype = preferred_dtype() for _ in range(3): # test multiple iterations to cover prefetching - activations: List[Tensor] = ds_engine( - torch.ones((param_sz, ), dtype=torch.float16, 
device=ds_engine.device)) + activations: List[Tensor] = ds_engine(torch.ones((param_sz, ), dtype=dtype, device=ds_engine.device)) assert len(activations) == n_layers partition_sz = math.ceil(param_sz / self.world_size) - expected_activations = torch.empty(param_sz, dtype=torch.float16, device=ds_engine.device) + expected_activations = torch.empty(param_sz, dtype=dtype, device=ds_engine.device) for start_idx in range(0, param_sz, partition_sz): expected_activations[start_idx:start_idx + partition_sz] = dist.get_rank() @@ -1007,11 +1011,11 @@ def __init_weights(self, module): "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1.0, - }, } + if get_accelerator().is_fp16_supported(): + ds_cfg["fp16"] = {"enabled": True, "loss_scale": 1.0} + elif get_accelerator().is_bf16_supported(): + ds_cfg["bf16"] = {"enabled": True} with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=True): model = ModelWhereParentInitializesChildWeights() @@ -1207,13 +1211,14 @@ def test(self): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 3 }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 class SubModel(torch.nn.Module): @@ -1284,9 +1289,6 @@ def test(self): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 1, "offload_optimizer": { @@ -1294,6 +1296,10 @@ def test(self): } }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -1311,6 +1317,8 @@ class TestZero3DictFwd(DistributedTest): world_size = 1 def test(self, return_type): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -1320,13 +1328,14 @@ def test(self, return_type): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 3 }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 class MyModel(torch.nn.Module): @@ -1391,11 +1400,11 @@ def test(self, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim, nlayers=12) @@ -1445,13 +1454,14 @@ def test(self, zero_stage): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 class MyModel(torch.nn.Module): @@ -1497,9 +1507,6 @@ def test(self, force_ds_optim): "train_batch_size": 4, "gradient_accumulation_steps": 2, "steps_per_print": 1, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 1, "offload_optimizer": { @@ -1508,6 +1515,10 @@ def test(self, force_ds_optim): }, "zero_force_ds_cpu_optimizer": force_ds_optim, } + if get_accelerator().is_fp16_supported(): + 
config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -1529,15 +1540,15 @@ def test_training_partition_cache(self, training): hidden_dim = 10 config_dict = { "train_batch_size": 2, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 3, "stage3_param_persistence_threshold": hidden_dim, }, } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} if training: config_dict["optimizer"] = {"type": "Adam"} @@ -1546,13 +1557,11 @@ def test_training_partition_cache(self, training): model, _, _, _ = deepspeed.initialize(model=model, config=config_dict) - dtype = torch.half data_loader = random_dataloader( model=model, total_samples=6, hidden_dim=hidden_dim, device=model.device, - dtype=dtype, ) for _, batch in enumerate(data_loader): @@ -1576,6 +1585,8 @@ class TestEmptyParameterGroup(DistributedTest): world_size = 1 def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_group): + if dtype == torch.float16 and not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") model = SimpleModel(hidden_dim=4, nlayers=4) param_groups = [ { diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index 0ddf1026eaf8..ec9e9e94aeaf 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -6,11 +6,13 @@ from types import SimpleNamespace import torch +import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel from utils import setup_serial_env @@ -47,16 +49,17 @@ def forward(self, x): "lr": 0.00015 } }, - "fp16": { - "enabled": True, - "loss_scale": 138. 
- }, "zero_optimization": { "stage": 3, "stage3_param_persistence_threshold": 1, } } +if get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} +elif get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + class TestZeroGatheredParametersFree(DistributedTest): world_size = 1 @@ -124,6 +127,8 @@ def test_scattered_init_dist(self): assert dist.is_initialized() def test_scatter_halftype(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") setup_serial_env() with deepspeed.zero.Init(): @@ -248,7 +253,7 @@ def forward(self, input): with deepspeed.zero.GatheredParameters(net.linear1.weight): assert net.linear1.weight.numel() == net.dim**2 - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) engine.backward(loss) engine.step() diff --git a/tests/unit/runtime/zero/test_zero_context_return.py b/tests/unit/runtime/zero/test_zero_context_return.py index 874a8ea3b676..9d49b6d3ba88 100644 --- a/tests/unit/runtime/zero/test_zero_context_return.py +++ b/tests/unit/runtime/zero/test_zero_context_return.py @@ -8,9 +8,10 @@ import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.accelerator import get_accelerator from utils import setup_serial_env -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype class DanglingBias(torch.nn.Linear): @@ -119,16 +120,17 @@ def forward(self, input): "lr": 0.00015 } }, - "fp16": { - "enabled": True, - "loss_scale": 138. - }, "zero_optimization": { "stage": 3, "stage3_param_persistence_threshold": 1, } } +if get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} +elif get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + class TestReturnParam(DistributedTest): world_size = 1 @@ -142,7 +144,7 @@ def test_ext_param_return(self): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(5): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) engine.backward(loss) engine.step() @@ -158,7 +160,7 @@ def test_ext_param_returnobj(self): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(5): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) assert len(net._external_params) == 1 assert len(net.dangler._external_params) == 0 @@ -176,7 +178,7 @@ def test_stage_3_output_type(self, output_type): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(1): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) if loss is not None: if isinstance(loss, dict): diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 0855acec57e3..1d3b88a04a4e 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -6,11 +6,12 @@ import deepspeed.comm as dist import torch -from unit.common import DistributedTest +from unit.common import 
DistributedTest, preferred_dtype from unit.simple_model import random_dataloader import deepspeed from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module +from deepspeed.accelerator import get_accelerator class ChooseModuleByCounter(torch.nn.Module): @@ -89,9 +90,6 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad): "lr": 1e-6 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 3, "stage3_prefetch_bucket_size": hidden_dim**2, @@ -99,6 +97,10 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad): "stage3_max_reuse_distance": 0, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} model = cls(hidden_dim) @@ -106,7 +108,7 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad): set_z3_leaf_modules(model, [cls]) assert z3_leaf_module(model) - run_model(model, config_dict, hidden_dim, torch.float16, requires_grad) + run_model(model, config_dict, hidden_dim, preferred_dtype(), requires_grad) def test_choose_module_by_counter(self): self._test_set_z3_leaf_modules(ChooseModuleByCounter, True) diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index b3adfdf96c50..3bb4af3e3d91 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -7,7 +7,7 @@ import deepspeed.comm as dist import torch -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import random_dataloader, SimpleModel from unit.util import bf16_required_version_check @@ -18,6 +18,7 @@ from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator WEIGHT_KEY = 'weight' FIRST_ORDER_KEY = 'exp_avg' @@ -112,14 +113,14 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz "lr": 1e-6 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 2 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 2} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} if offload_device == OffloadDeviceEnum.cpu: config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} @@ -139,9 +140,12 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_after_bwd, validate_after_step) + run_fragmented_model(model, config_dict, hidden_dim, preferred_dtype(), validate_after_bwd, + validate_after_step) def test_bf16_fragments(self, frozen_weights): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") if frozen_weights: pytest.skip("TODO: Frozen weights not currently supported by BF16 Optimizer") @@ -302,6 +306,8 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp } if dtype == torch.float16: + if not 
get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} elif dtype == torch.bfloat16: config_dict["bf16"] = {"enabled": True} diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index 01ce3d2fe4c9..3357c200bd68 100644 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -14,6 +14,7 @@ from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist +from .common import preferred_dtype class SimpleModel(torch.nn.Module): @@ -262,21 +263,21 @@ def forward(self, x, y, **kwargs): return hidden_dim -def random_dataset(total_samples, hidden_dim, device, dtype=torch.half): +def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype()): train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) train_dataset = torch.utils.data.TensorDataset(train_data, train_label) return train_dataset -def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half): +def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()): batch_size = model.train_micro_batch_size_per_gpu() train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) return train_loader -def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=torch.half): +def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()): batch_size = model.train_micro_batch_size_per_gpu() train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype) train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
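
Note on the helper used throughout the test changes above: the hunks import preferred_dtype() from tests/unit/common.py (see "from .common import preferred_dtype" in simple_model.py and "from unit.common import DistributedTest, preferred_dtype" in the test files). The sketch below is only an illustration of the selection logic this patch appears to rely on — fp16 when the active accelerator supports it, otherwise bf16, otherwise fp32 — and is not taken from the actual tests/unit/common.py hunk, which may differ.

    # Hypothetical sketch of preferred_dtype(); assumptions: simple accelerator
    # capability checks, fp32 as the final fallback (e.g. CPU accelerator).
    import torch
    from deepspeed.accelerator import get_accelerator

    def preferred_dtype():
        # Prefer half precision when the accelerator supports it, so tests
        # exercise the same mixed-precision path as on CUDA devices.
        if get_accelerator().is_fp16_supported():
            return torch.float16
        elif get_accelerator().is_bf16_supported():
            return torch.bfloat16
        else:
            return torch.float32

Under this assumption, call sites such as random_dataloader(..., dtype=preferred_dtype()) and torch.tensor(..., dtype=preferred_dtype()) automatically match whichever precision the launch-time accelerator reports, which is the same pattern the config_dict fp16/bf16 selection blocks follow.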