Make op builder detection adapt to accelerator change (#5206)
This is a WIP PR that makes op builder detection adapt to accelerator
change. It is a follow-up of
#5173
Currently, DeepSpeed generates `installed_ops` and `compatible_ops` at
setup time. If the system changes to a different accelerator at DeepSpeed
launch time, these two lists would contain incorrect information.

This PR intends to solve the problem with more flexible op detection:

* For `installed_ops`, DeepSpeed should disable all installed ops if the
accelerator detected at setup time differs from the one detected at
launch time.
* For `compatible_ops`, DeepSpeed should refresh the list on each launch
to avoid the impact of an accelerator change.

As a first step, the nv-inference workflow is temporarily changed to
emulate the scenario where the system is set up with CPU_Accelerator and
then launched with CUDA_Accelerator, and CPU_Accelerator is modified so
that Intel Extension for PyTorch and the oneCCL Bindings for PyTorch are
no longer mandatory.

Starting from here we can reconstruct installed_ops and compatible_ops to
follow the design above; a sketch of that design is given below.
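A minimal sketch of the intended detection logic, using only names that
appear in this diff (`installed_ops`, `accelerator_name`, `ALL_OPS`,
`get_accelerator`); the two helper functions are illustrative, not part
of DeepSpeed's API:

    # Sketch only: the helper function names are hypothetical.
    from deepspeed.accelerator import get_accelerator
    from deepspeed.ops.op_builder.all_ops import ALL_OPS
    # accelerator_name is recorded by setup.py at build time.
    from deepspeed.git_version_info import installed_ops, accelerator_name

    def effective_installed_ops():
        # Disable every pre-built op if the accelerator changed since setup.
        if accelerator_name != get_accelerator()._name:
            return dict.fromkeys(installed_ops, False)
        return installed_ops

    def refreshed_compatible_ops():
        # Rebuilt on every launch against the live accelerator.
        return {name: builder.is_compatible() for name, builder in ALL_OPS.items()}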

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
3 people authored Mar 12, 2024
1 parent 535a908 commit c08e69f
Showing 49 changed files with 567 additions and 305 deletions.
23 changes: 3 additions & 20 deletions .github/workflows/cpu-inference.yml
@@ -47,42 +47,26 @@ jobs:
- name: Detect instruction sets on instance
run: |
lscpu
pip install cmake
git clone https://github.com/intel/intel-extension-for-pytorch
cd intel-extension-for-pytorch/tests/cpu/isa
cmake .
make
./cpu_features
- name: Install numactl
run: |
sudo apt-get install -y numactl
- name: Install oneCCL Bindings for PyTorch
- name: Install dependencies
run: |
pip install torch
python -m pip install intel_extension_for_pytorch
# the curl line is for troubleshooting
curl -L https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
python -m pip install oneccl_bind_pt --index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install py-cpuinfo
# check installed version
pip list |grep \\\<torch\\\>
pip list |grep intel-extension-for-pytorch
pip list |grep oneccl-bind-pt
- name: Install oneCCL
run: |
pip install cmake
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
mkdir build
cd build
cmake ..
make
make install
#source ./_install/env/setvars.sh
# test whether oneCCL is correctly installed
#mpirun -n 2 ./examples/benchmark/benchmark
make -j install
- name: Install transformers
run: |
@@ -103,7 +87,6 @@ jobs:
source oneCCL/build/_install/env/setvars.sh
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
# check whether the environment is properly setup
python -c "import torch;import intel_extension_for_pytorch as ipex;import oneccl_bindings_for_pytorch;print('done')"
python -c "import deepspeed;from deepspeed.accelerator import get_accelerator;print(get_accelerator().device_name());print(get_accelerator().is_available())"
- name: Unit tests
4 changes: 4 additions & 0 deletions .github/workflows/cpu-torch-latest.yml
@@ -27,6 +27,10 @@ jobs:
- id: setup-venv
uses: ./.github/workflows/setup-venv

- name: Install system packages
run: |
sudo apt-get install -y numactl pdsh
- name: Install pytorch
run: |
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
5 changes: 4 additions & 1 deletion .github/workflows/nv-inference.yml
@@ -46,7 +46,8 @@ jobs:
- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning,inf,triton]
DS_ACCELERATOR=cpu pip install .[dev,1bit,autotuning,inf]
#pip install .[dev,1bit,autotuning,inf,triton]
ds_report
- name: Python environment
@@ -60,3 +61,5 @@
#pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
# run ds_report again to check updated op list
ds_report
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
@@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
28 changes: 23 additions & 5 deletions accelerator/cpu_accelerator.py
@@ -4,9 +4,14 @@
# DeepSpeed Team

import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
import psutil
from .abstract_accelerator import DeepSpeedAccelerator

try:
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
oneccl_imported_p = True
except ImportError as e:
oneccl_imported_p = False

import os


@@ -15,8 +20,17 @@ class CPU_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'cpu'
self._communication_backend_name = 'ccl'
self.max_mem = psutil.Process().memory_info().rss
if oneccl_imported_p:
self._communication_backend_name = 'ccl'
else:
# fall back to gloo if oneccl_bindings_for_pytorch is not installed
self._communication_backend_name = 'gloo'
try:
import psutil
mem = psutil.Process().memory_info().rss
self.max_mem = mem
except ImportError as e:
self.max_mem = 0

def is_synchronized_device(self):
return True
@@ -115,12 +129,14 @@ def empty_cache(self):
return

def get_rss(self):
import psutil
mem = psutil.Process().memory_info().rss
if mem > self.max_mem:
self.max_mem = mem
return mem

def reset_rss(self):
import psutil
mem = psutil.Process().memory_info().rss
self.max_mem = mem
return mem
@@ -166,9 +182,11 @@ def max_memory_reserved(self, device_index=None):
return self.max_mem

def total_memory(self, device_index=None):
import psutil
return psutil.virtual_memory().total

def available_memory(self, device_index=None):
import psutil
return psutil.virtual_memory().available

# Misc
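With the hunks above, a bare CPU install (no oneCCL bindings, no psutil)
still constructs the accelerator. A hedged usage sketch of that degraded
mode, assuming the import path shown in the diff and the
communication_backend_name() accessor from the abstract accelerator
interface:

    from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator

    acc = CPU_Accelerator()
    # 'ccl' when oneccl_bindings_for_pytorch imports cleanly, else 'gloo'.
    print(acc.communication_backend_name())
    # 0 when psutil is unavailable; otherwise the current RSS in bytes.
    print(acc.max_mem)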
24 changes: 18 additions & 6 deletions accelerator/real_accelerator.py
@@ -73,11 +73,7 @@ def get_accelerator():
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
)
elif accelerator_name == "cpu":
try:
import intel_extension_for_pytorch # noqa: F401 # type: ignore
except ImportError as e:
raise ValueError(
f"CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
pass
elif accelerator_name == "npu":
try:
import torch_npu # noqa: F401 # type: ignore
@@ -154,7 +150,23 @@ def get_accelerator():
except ImportError as e:
pass
if accelerator_name is None:
accelerator_name = "cuda"
# borrow this log from PR#5084
try:
import torch

# Determine if we are on a GPU or x86 CPU with torch.
if torch.cuda.is_available(): #ignore-cuda
accelerator_name = "cuda"
else:
if accel_logger is not None:
accel_logger.warn(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it."
)
accelerator_name = "cpu"
except (RuntimeError, ImportError) as e:
# TODO: need a more robust way to detect which accelerator to use; consider using the nvidia-smi command for detection
accelerator_name = "cuda"
pass

ds_set_method = "auto detect"

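The workflows in this PR drive the detection path above through the
DS_ACCELERATOR environment variable. A short usage sketch (the printed
values are what the changed workflow checks for, not guaranteed output):

    import os

    # Must be set before get_accelerator() is first called.
    os.environ["DS_ACCELERATOR"] = "cpu"

    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()
    print(acc.device_name())   # e.g. 'cpu'
    print(acc.is_available())  # the CPU accelerator reports True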
5 changes: 3 additions & 2 deletions deepspeed/env_report.py
@@ -9,7 +9,7 @@
import subprocess
import argparse
from .ops.op_builder.all_ops import ALL_OPS
from .git_version_info import installed_ops, torch_info
from .git_version_info import installed_ops, torch_info, accelerator_name
from deepspeed.accelerator import get_accelerator

GREEN = '\033[92m'
@@ -51,7 +51,8 @@ def op_report(verbose=True):
for op_name, builder in ALL_OPS.items():
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible(verbose) else no
is_installed = installed if installed_ops.get(op_name, False) else no
is_installed = installed if installed_ops.get(op_name,
False) and accelerator_name == get_accelerator()._name else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
11 changes: 10 additions & 1 deletion deepspeed/git_version_info.py
@@ -18,5 +18,14 @@

from .ops.op_builder.all_ops import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
accelerator_name = ""
torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"}

# compatible_ops list is recreated for each launch
from .ops.op_builder.all_ops import ALL_OPS

compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()
compatible_ops[op_name] = op_compatible
compatible_ops["deepspeed_not_implemented"] = False
2 changes: 0 additions & 2 deletions deepspeed/ops/__init__.py
@@ -7,8 +7,6 @@
from . import adagrad
from . import lamb
from . import lion
#from ..git_version_info_installed import installed_ops as __installed_ops__
#if __installed_ops__['sparse_attn']:
from . import sparse_attention
from . import transformer

9 changes: 6 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -56,7 +56,8 @@ def __init__(self, param: Parameter) -> None:
self.__param = param

def wait(self) -> None:
get_accelerator().current_stream().synchronize()
if not get_accelerator().is_synchronized_device():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE


@@ -81,7 +82,8 @@ def wait(self) -> None:
if self.__complete:
return

get_accelerator().current_stream().synchronize()
if not get_accelerator().is_synchronized_device():
get_accelerator().current_stream().synchronize()
for param in self.__params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
param.ds_status = ZeroParamStatus.AVAILABLE
@@ -363,7 +365,8 @@ def _set_dtype(self, ds_config, dtype):
else:
self.dtype = torch.float
else:
self.dtype = dtype or torch.half
self.dtype = dtype or (torch.float16 if get_accelerator().is_fp16_supported() else
                       torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32)

def patch_init_and_builtins(self):

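Restated as a standalone helper, the fallback order above is: explicit
dtype first, then fp16 if the accelerator supports it, then bf16, then
fp32. A sketch using the accelerator API from this PR (the helper name is
illustrative):

    import torch
    from deepspeed.accelerator import get_accelerator

    def default_dtype(dtype=None):
        # An explicit dtype always wins.
        if dtype is not None:
            return dtype
        acc = get_accelerator()
        if acc.is_fp16_supported():
            return torch.float16
        if acc.is_bf16_supported():
            return torch.bfloat16
        return torch.float32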
1 change: 1 addition & 0 deletions op_builder/all_ops.py
@@ -30,3 +30,4 @@
__op_builders__.append(builder)

ALL_OPS = {op.name: op for op in __op_builders__ if op is not None}
accelerator_name = get_accelerator()._name
5 changes: 3 additions & 2 deletions op_builder/builder.py
@@ -464,8 +464,9 @@ def load(self, verbose=True):
if self.name in __class__._loaded_ops:
return __class__._loaded_ops[self.name]

from deepspeed.git_version_info import installed_ops, torch_info
if installed_ops.get(self.name, False):
from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name
from deepspeed.accelerator import get_accelerator
if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
# Ensure the op we're about to load was compiled with the same
# torch/cuda versions we are currently using at runtime.
self.validate_torch_version(torch_info)
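The practical effect of the new guard in load() is that an op pre-built
under a different accelerator is JIT-rebuilt instead of imported. A usage
sketch (the builder class is assumed to be exported by
deepspeed.ops.op_builder, as DeepSpeed's own tests do):

    from deepspeed.ops.op_builder import CPUAdamBuilder

    builder = CPUAdamBuilder()
    # Imports the pre-compiled extension only if cpu_adam was built under
    # the accelerator recorded at setup time; otherwise falls back to
    # builder.jit_load() and compiles against the current accelerator.
    op_module = builder.load()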
5 changes: 3 additions & 2 deletions op_builder/xpu/builder.py
@@ -74,8 +74,9 @@ def fixed_aotflags(self):
]

def load(self, verbose=True):
from deepspeed.git_version_info import installed_ops, torch_info # noqa: F401
if installed_ops.get(self.name, False):
from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name # noqa: F401
from deepspeed.accelerator import get_accelerator
if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
return importlib.import_module(self.absolute_name())
else:
return self.jit_load(verbose)
8 changes: 2 additions & 6 deletions setup.py
@@ -35,7 +35,7 @@
'Please visit https://pytorch.org/ to see how to properly install torch on your system.')

from op_builder import get_default_compute_capabilities, OpBuilder
from op_builder.all_ops import ALL_OPS
from op_builder.all_ops import ALL_OPS, accelerator_name
from op_builder.builder import installed_cuda_version

# Fetch rocm state.
@@ -168,12 +168,9 @@ def op_enabled(op_name):
return int(get_env_if_set(env_var, BUILD_OP_DEFAULT))


compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
install_ops = dict.fromkeys(ALL_OPS.keys(), False)
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()
compatible_ops[op_name] = op_compatible
compatible_ops["deepspeed_not_implemented"] = False

# If op is requested but not available, throw an error.
if op_enabled(op_name) and not op_compatible:
@@ -280,11 +277,10 @@ def create_dir_symlink(src, dest):
fd.write(f"git_hash='{git_hash}'\n")
fd.write(f"git_branch='{git_branch}'\n")
fd.write(f"installed_ops={install_ops}\n")
fd.write(f"compatible_ops={compatible_ops}\n")
fd.write(f"accelerator_name='{accelerator_name}'\n")
fd.write(f"torch_info={torch_info}\n")

print(f'install_requires={install_requires}')
print(f'compatible_ops={compatible_ops}')
print(f'ext_modules={ext_modules}')

# Parse README.md to make long_description for PyPI page.
11 changes: 7 additions & 4 deletions tests/unit/checkpoint/common.py
@@ -14,6 +14,7 @@
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from unit.common import preferred_dtype
from unit.simple_model import *
from unittest.mock import MagicMock, patch

@@ -163,13 +164,15 @@ def checkpoint_correctness_verification(config_dict,
tmpdir,
load_optimizer_states=False,
load_lr_scheduler_states=False,
fp16=True,
train_batch=False,
base_optimizers=[None, None],
empty_tag=False,
seq_dataloader=False,
load_module_only=False):
dtype = torch.half if fp16 else torch.float32
load_module_only=False,
dtype=None):
if dtype is None:
dtype = preferred_dtype()

ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0])

if seq_dataloader:
@@ -241,7 +244,7 @@ def checkpoint_correctness_verification(config_dict,
load_module_only=load_module_only)

if load_optimizer_states:
compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16)

if load_lr_scheduler_states:
compare_lr_scheduler_states(trained_model, loaded_model)
4 changes: 2 additions & 2 deletions tests/unit/checkpoint/test_latest_checkpoint.py
@@ -38,8 +38,8 @@ def test_existing_latest(self, tmpdir):
tmpdir=tmpdir,
load_optimizer_states=True,
load_lr_scheduler_states=False,
fp16=False,
empty_tag=True)
empty_tag=True,
dtype=torch.float)

def test_missing_latest(self, tmpdir):
config_dict = {
(Diff for the remaining changed files is not shown.)