Skip to content

Commit

Permalink
Merge branch 'master' into cholmes/checkpoints-inference-v2-2
Browse files Browse the repository at this point in the history
  • Loading branch information
cmikeh2 authored Nov 13, 2023
2 parents 601529e + 953e3e3 commit e69abe0
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 31 deletions.
2 changes: 1 addition & 1 deletion accelerator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# DeepSpeed Team

from .abstract_accelerator import DeepSpeedAccelerator
from .real_accelerator import get_accelerator, set_accelerator
from .real_accelerator import get_accelerator, set_accelerator, is_current_accelerator_supported
27 changes: 15 additions & 12 deletions accelerator/real_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
except ImportError as e:
dsa2 = None

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']

ds_accelerator = None


Expand All @@ -34,14 +36,18 @@ def _validate_accelerator(accel_obj):
# accelerator.abstractor_accelerator
# or deepspeed.accelerator.abstract_accelerator, consider accel_obj
# is a conforming object
if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))):
if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))):
raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")

# TODO: turn off is_available test since this breaks tests
# assert accel_obj.is_available(), \
# f'{accel_obj.__class__.__name__} accelerator fails is_available() test'


def is_current_accelerator_supported():
return get_accelerator() in SUPPORTED_ACCELERATOR_LIST


def get_accelerator():
global ds_accelerator
if ds_accelerator is not None:
Expand All @@ -50,7 +56,6 @@ def get_accelerator():
accelerator_name = None
ds_set_method = None
# 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
DS_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
if "DS_ACCELERATOR" in os.environ.keys():
accelerator_name = os.environ["DS_ACCELERATOR"]
if accelerator_name == "xpu":
Expand Down Expand Up @@ -79,15 +84,13 @@ def get_accelerator():
torch.mps.current_allocated_memory()
except (RuntimeError, ImportError) as e:
raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
elif accelerator_name == "cuda":
pass
else:
raise ValueError(
f'DS_ACCELERATOR must be one of {DS_ACCELERATOR_LIST}. Value "{accelerator_name}" is not supported')
elif is_current_accelerator_supported():
raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
f'Value "{accelerator_name}" is not supported')
ds_set_method = "override"

# 2. If no override, detect which accelerator to use automatically
if accelerator_name == None:
if accelerator_name is None:
# We need a way to choose among different accelerator types.
# Currently we detect which accelerator extension is installed
# in the environment and use it if the installing answer is True.
Expand All @@ -105,21 +108,21 @@ def get_accelerator():
accelerator_name = "xpu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import intel_extension_for_pytorch # noqa: F401,F811 # type: ignore

accelerator_name = "cpu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore

accelerator_name = "npu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import torch.mps

Expand All @@ -128,7 +131,7 @@ def get_accelerator():
accelerator_name = "mps"
except (RuntimeError, ImportError) as e:
pass
if accelerator_name == None:
if accelerator_name is None:
accelerator_name = "cuda"

ds_set_method = "auto detect"
Expand Down
7 changes: 7 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import logging
import os
import pickle
from typing import Any
from packaging import version

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
Expand Down Expand Up @@ -94,6 +96,11 @@ def build_hf_engine(path: str,
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
# Ensure we're using the correct version of transformers for mistral
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
Expand Down
22 changes: 16 additions & 6 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,19 @@
TORCH_MINOR = int(torch.__version__.split('.')[1])


class MissingCUDAException(Exception):
pass


class CUDAMismatchException(Exception):
pass


def installed_cuda_version(name=""):
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
if cuda_home is None:
raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
# Ensure there is not a cuda version mismatch between torch and nvcc compiler
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
output_split = output.split()
Expand Down Expand Up @@ -89,9 +98,10 @@ def assert_no_cuda_mismatch(name=""):
"Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior."
)
return True
raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
f"version torch was compiled with {torch.version.cuda}, unable to compile "
"cuda/cpp extensions without a matching cuda version.")
raise CUDAMismatchException(
f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
f"version torch was compiled with {torch.version.cuda}, unable to compile "
"cuda/cpp extensions without a matching cuda version.")
return True


Expand Down Expand Up @@ -339,7 +349,7 @@ def is_cuda_enable(self):
try:
assert_no_cuda_mismatch(self.name)
return '-D__ENABLE_CUDA__'
except BaseException:
except MissingCUDAException:
print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
"only cpu ops can be compiled!")
return '-D__DISABLE_CUDA__'
Expand Down Expand Up @@ -601,7 +611,7 @@ def builder(self):
if not self.is_rocm_pytorch():
assert_no_cuda_mismatch(self.name)
self.build_for_cpu = False
except BaseException:
except MissingCUDAException:
self.build_for_cpu = True

if self.build_for_cpu:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def cifar_trainset(fp16=False):


def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]):
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
device_type=get_accelerator().device_name()):
ds_utils.set_random_seed(seed)

# disable dropout
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def set_accelerator_visible():
match = re.search('Device Type.*GPU', line)
if match:
num_accelerators += 1
elif get_accelerator().device_name() == 'npu':
npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
else:
assert get_accelerator().device_name() == 'cpu'
cpu_sockets = int(
Expand Down Expand Up @@ -204,13 +207,13 @@ def _dist_run(self, local_rank, num_procs, master_port):
if get_accelerator().is_available():
set_accelerator_visible()

if get_accelerator().is_available():
get_accelerator().set_device(local_rank)

if self.init_distributed:
deepspeed.init_distributed(dist_backend=self.backend)
dist.barrier()

if get_accelerator().is_available():
get_accelerator().set_device(local_rank)

try:
self.run(**self._fixture_kwargs)
except BaseException as e:
Expand Down
20 changes: 13 additions & 7 deletions tests/unit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@

import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
from deepspeed.git_version_info import torch_info
from packaging import version as pkg_version


def skip_on_arch(min_arch=7):
if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
if get_accelerator().device_name() == 'cuda':
if torch.cuda.get_device_capability()[0] < min_arch: #ignore-cuda
pytest.skip(f"needs higher compute capability than {min_arch}")
else:
assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
assert is_current_accelerator_supported()
return


def skip_on_cuda(valid_cuda):
split_version = lambda x: map(int, x.split('.')[:2])
if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
if get_accelerator().device_name() == 'cuda':
CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR
if valid_cuda.count(CUDA_VERSION) == 0:
pytest.skip(f"requires cuda versions {valid_cuda}")
else:
assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
assert is_current_accelerator_supported()
return


Expand All @@ -43,8 +43,14 @@ def bf16_required_version_check(accelerator_check=True):
else:
accelerator_pass = True

if (TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and (
NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and accelerator_pass:
torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
cuda_version_available = CUDA_MAJOR >= 11
nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
npu_available = get_accelerator().device_name() == 'npu'

if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
return True
elif npu_available:
return True
else:
return False
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.12.3
0.12.4

0 comments on commit e69abe0

Please sign in to comment.