
[Fix] Fix cpu inference UT failure #4430

Merged: 50 commits, Jan 9, 2024

Changes from 49 commits

Commits (50):
e60e645
add a whitespace change that breaks formatting
delock Sep 28, 2023
ed95d21
fix TestModelTask
delock Sep 30, 2023
f0022b0
Skip TestModelTask if InferenceBuilder is not implemented
delock Oct 1, 2023
af2f380
remove blank change
delock Oct 1, 2023
257ed96
Merge branch 'master' into gma/fix_cpu_inference
loadams Oct 2, 2023
ac4254f
Reuse hf_model list among tests to avoid slow loading (#16)
delock Oct 7, 2023
cc0294f
Change COLUMNS to 140 to allow display of pytest skip message; Sanity…
delock Oct 8, 2023
af6661a
Merge branch 'master' into gma/fix_cpu_inference
delock Oct 8, 2023
861088f
Gma/fix cpu inference local (#19)
delock Oct 25, 2023
48787d9
change cpu inference test to self hosted v100 runner
delock Oct 25, 2023
f516fbd
Merge branch 'master' into gma/fix_cpu_inference
loadams Oct 25, 2023
17183bd
Running on self-hosted cpu rather than cuda machine.
delock Oct 26, 2023
34b2570
Merge branch 'master' into gma/fix_cpu_inference
loadams Oct 26, 2023
f40a484
remove ad-hoc running of cpu-inference
delock Oct 30, 2023
4ed3b60
update ccl.py for error type (#24)
Liangliang-Ma Oct 30, 2023
bac6bb6
Merge branch 'master' into gma/fix_cpu_inference
delock Oct 31, 2023
577b292
Merge branch 'master' into gma/fix_cpu_inference
loadams Oct 31, 2023
15295ae
install gcc-9 in cpu workflow
delock Nov 3, 2023
8d182cb
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 3, 2023
d52ff77
set gcc/g++ default to 9 in cpu inference workflow
delock Nov 4, 2023
0c6fa89
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 6, 2023
3dd7d34
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 6, 2023
e9fafa7
update oneccl_bind_pt installation steps
delock Nov 7, 2023
50bba12
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 8, 2023
1bd0dfb
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 13, 2023
5e41955
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 15, 2023
51922e4
mitigation for oneCCL GLIBCXX_3.4.30 not found issue
delock Nov 16, 2023
c98752b
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 16, 2023
fc6025c
use sudo to install conda package
delock Nov 18, 2023
a8cec8b
ccl issues fix (#32)
Liangliang-Ma Nov 21, 2023
9fb8ecb
clean up all_reduce_caching path
delock Nov 21, 2023
590c959
Merge branch 'master' into gma/fix_cpu_inference
delock Nov 21, 2023
c4cabcd
fix formatting
delock Nov 22, 2023
3663b75
preload libstdc++ from system lib path instead of conda path
delock Nov 23, 2023
7ca2ba5
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 28, 2023
3934919
prep oneCCL before running unit tests
delock Nov 29, 2023
4bf6493
Merge branch 'master' into gma/fix_cpu_inference
loadams Nov 29, 2023
b90fa99
prep libstdc++ in UT run
delock Dec 2, 2023
62d835c
Merge branch 'master' into gma/fix_cpu_inference
tjruwase Dec 4, 2023
8055034
allow codegen test for bf16
delock Dec 5, 2023
71d1106
Merge branch 'master' into gma/fix_cpu_inference
loadams Dec 5, 2023
b50a481
disable codegen bf16
delock Dec 7, 2023
3dce178
Merge branch 'master' into gma/fix_cpu_inference
loadams Dec 7, 2023
1596224
Merge branch 'master' into gma/fix_cpu_inference
loadams Jan 2, 2024
a72beea
fix test_inference_config UT error
delock Jan 3, 2024
057b6ff
Merge branch 'master' into gma/fix_cpu_inference
tjruwase Jan 3, 2024
5886645
Merge branch 'master' into gma/fix_cpu_inference
loadams Jan 3, 2024
3244e1f
fix typo
delock Jan 5, 2024
21b438c
Merge branch 'master' into gma/fix_cpu_inference
mrwyattii Jan 5, 2024
2e6fa99
Merge branch 'master' into gma/fix_cpu_inference
loadams Jan 8, 2024
49 changes: 43 additions & 6 deletions .github/workflows/cpu-inference.yml
@@ -1,22 +1,43 @@
name: cpu-inference

on:
pull_request:
paths-ignore:
- 'docs/**'
- 'blogs/**'
workflow_dispatch:
merge_group:
branches: [ master ]


concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
unit-tests:
runs-on: ubuntu-20.04
runs-on: [self-hosted, cpu]

steps:
- uses: actions/checkout@v3

- id: setup-venv
uses: ./.github/workflows/setup-venv

- name: Install gcc-9
run: |
sudo add-apt-repository -u ppa:ubuntu-toolchain-r/test
sudo apt install -y gcc-9 g++-9
# set gcc-9 and g++-9 to default
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 99
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-9 99

- name: Check gcc version
run: |
# Get gcc version
gcc --version
g++ --version

- name: Detect instruction sets on instance
run: |
lscpu
@@ -33,8 +54,16 @@ jobs:

- name: Install oneCCL Bindings for PyTorch
run: |
pip install torch
python -m pip install intel_extension_for_pytorch
python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu
# the curl line is for troubleshooting
curl -L https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install py-cpuinfo
# check installed version
pip list |grep \\\<torch\\\>
pip list |grep intel-extension-for-pytorch
pip list |grep oneccl-bind-pt

- name: Install oneCCL
run: |
@@ -62,14 +91,22 @@ jobs:
pip install .[dev,1bit,autotuning,inf]
ds_report

- name: Python environment
- name: Python environment check
run: |
pip list
source oneCCL/build/_install/env/setvars.sh
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
# check whether the environment is properly setup
python -c "import torch;import intel_extension_for_pytorch as ipex;import oneccl_bindings_for_pytorch;print('done')"
python -c "import deepspeed;from deepspeed.accelerator import get_accelerator;print(get_accelerator().device_name());print(get_accelerator().is_available())"

- name: Unit tests
run: |
# prep oneCCL for CCLBackend comm ops building
source oneCCL/build/_install/env/setvars.sh
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
cd tests
# LOCAL_SIZE=2 forces CPU to report 2 devices; this helps run the test on the GitHub default runner
LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
4 changes: 2 additions & 2 deletions csrc/cpu/comm/ccl.cpp
@@ -537,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
}
}

void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
{
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
@@ -562,7 +562,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> g
data.numel(),
get_ccl_datatype(data.scalar_type()),
get_ccl_reduce_op(op, data),
_get_comm_from_group(group))
_get_comm_from_group())
.wait());
return;
}
46 changes: 31 additions & 15 deletions deepspeed/comm/ccl.py
@@ -61,7 +61,8 @@ def is_initialized(self):

def run_collective(self, name, **kwargs):
if name in self.available_coll:
kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
if 'group' in kwargs:
kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
if 'dst' in kwargs:
kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
if 'src' in kwargs:
@@ -71,23 +72,38 @@ def run_collective(self, name, **kwargs):
return CCLHandler(self.ccl_comm_op)
else:
func = "super(CCLBackend, self)." + name
return eval(func)(*(kwargs.values()))
eval(func)(*(kwargs.values()))
return CCLHandler(self.ccl_comm_op)

def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
return self.run_collective(name="all_reduce_caching",
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
name = "all_reduce_caching"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
else:
return self.run_collective(name=name,
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
else:
return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
name = "all_reduce"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
name = "inference_all_reduce"
if name in self.available_coll:
return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op)

def broadcast(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
@@ -120,11 +136,11 @@ def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes
input_split_sizes=input_split_sizes,
group=group)

def send(self, tensor, dst, group=None, async_op=False):
return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)
def send(self, tensor, dst, group=None, tag=0):
return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, tag=tag)

def recv(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)
def recv(self, tensor, src, group=None, tag=0):
return self.run_collective(name="recv", tensor=tensor, src=src, group=group, tag=tag)

def gather(self, tensor, gather_list, dst, group=None, async_op=False):
return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
@@ -170,7 +186,7 @@ def get_all_ranks_from_group(self, group):
while True:
results.append(super(CCLBackend, self).get_global_rank(group, rank))
rank += 1
except ValueError:
except (ValueError, RuntimeError):
Comment from a Contributor: What is the runtime error that we can hit here?

Reply from delock (Collaborator, Author), Jan 5, 2024: The `while True:` loop iterates over the local ranks and collects the global rank for each one, until the local rank goes out of range. Older versions of PyTorch throw a ValueError for this out-of-range access; PyTorch 2 throws a RuntimeError instead. (A minimal sketch of this pattern follows the diff below.)

pass
if tuple(results) not in self.groups:
self._new_group(results, group)
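
For readers outside the PR, here is a minimal, self-contained sketch of the pattern discussed in the comment above: probe local ranks until the lookup raises, and treat both exception types as the end-of-group signal. The collect_global_ranks helper and the stand-in lookup below are illustrative only and are not part of DeepSpeed's API.

```python
# Minimal sketch (not the DeepSpeed implementation): gather the global rank of
# every local rank in a group by probing until the lookup signals "out of range".
# Older torch.distributed versions raised ValueError for this case; PyTorch 2
# raises RuntimeError, hence the combined except clause.
def collect_global_ranks(get_global_rank, group):
    results = []
    rank = 0
    try:
        while True:
            results.append(get_global_rank(group, rank))
            rank += 1
    except (ValueError, RuntimeError):
        pass  # ran past the last local rank in the group; stop probing
    return results


# Stand-in lookup that mimics the PyTorch 2 behavior, for demonstration only.
def _stand_in_get_global_rank(group, rank):
    if rank >= len(group):
        raise RuntimeError("rank out of range")
    return group[rank]


print(collect_global_ranks(_stand_in_get_global_rank, [4, 5, 6]))  # -> [4, 5, 6]
```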
2 changes: 1 addition & 1 deletion docs/_tutorials/accelerator-abstraction-interface.md
@@ -96,7 +96,7 @@ To run DeepSpeed model on CPU, use the following steps to prepare environment:

```
python -m pip install intel_extension_for_pytorch
python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu
python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
mkdir build
18 changes: 15 additions & 3 deletions tests/unit/inference/test_inference.py
@@ -5,6 +5,7 @@

import os
import time
import pickle
import torch
import pytest
import itertools
@@ -65,7 +66,13 @@
]

# Get a list of all models and mapping from task to supported models
_hf_models = list(HfApi().list_models())
try:
with open("hf_models.pkl", "rb") as fp:
_hf_models = pickle.load(fp)
except FileNotFoundError:
_hf_models = list(HfApi().list_models())
with open("hf_models.pkl", "wb") as fp:
pickle.dump(_hf_models, fp)
Comment from a Contributor on lines +69 to +75: I think that caching the model list can be a good idea for the tests, but we need to save it to blob storage so that it is persistent. Additionally, I think the cache should have a timestamp attached to it, so that we update it every hour/day/week. See how we do this in MII:
https://github.com/microsoft/DeepSpeed-MII/blob/4472e4e206182ed56399f225848a7721565922fb/mii/utils.py#L39
(A sketch of such a timestamped cache follows this file's diff.)
_hf_model_names = [m.modelId for m in _hf_models]
_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}

@@ -280,6 +287,12 @@ def test(
if invalid_test_msg:
pytest.skip(invalid_test_msg)

if dtype not in get_accelerator().supported_dtypes():
pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))

@@ -536,9 +549,8 @@ def test(
if dtype not in get_accelerator().supported_dtypes():
pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")

# TODO: enable this test after torch 2.1 stable release
if dtype == torch.bfloat16 and model_w_task[0] == "Salesforce/codegen-350M-mono":
pytest.skip("Codegen model(bf16) need to use torch version > 2.0.")
pytest.skip("Disable Codegen model(bf16) due to slight result difference")

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
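As a follow-up to the caching discussion above, here is a minimal sketch of a timestamped model-list cache, assuming a local pickle file rather than blob storage. The load_hf_model_list helper, the file name, and the one-day refresh interval are illustrative assumptions, not what the PR merged.

```python
import os
import pickle
import time

from huggingface_hub import HfApi

CACHE_FILE = "hf_models.pkl"    # illustrative path; persistent blob storage could replace it
CACHE_MAX_AGE_S = 24 * 60 * 60  # assumed refresh policy: once a day


def load_hf_model_list(cache_file=CACHE_FILE, max_age_s=CACHE_MAX_AGE_S):
    """Return the Hugging Face model list, refreshing the cache when it is stale."""
    if os.path.exists(cache_file):
        age_s = time.time() - os.path.getmtime(cache_file)
        if age_s < max_age_s:
            with open(cache_file, "rb") as fp:
                return pickle.load(fp)
    models = list(HfApi().list_models())
    with open(cache_file, "wb") as fp:
        pickle.dump(models, fp)
    return models


# Usage in the test module would then reduce to a single call:
# _hf_models = load_hf_model_list()
```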
4 changes: 2 additions & 2 deletions tests/unit/inference/test_inference_config.py
@@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest):
world_size = 1

def test_overlap_kwargs(self):
config = {"replace_with_kernel_inject": True}
config = {"replace_with_kernel_inject": True, "dtype": torch.float32}
kwargs = {"replace_with_kernel_inject": True}

engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)
Expand All @@ -37,7 +37,7 @@ def test_kwargs_and_config(self):
assert engine._config.dtype == kwargs["dtype"]

def test_json_config(self, tmpdir):
config = {"replace_with_kernel_inject": True}
config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"}
config_json = create_config_from_dict(tmpdir, config)

engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)