Make op builder detection adapt to accelerator change #5206

Merged: 60 commits from gma/launch_opbuilder_detection into master on Mar 12, 2024

Changes from 1 commit

Commits
238bd1c
update nv-inference.yml for launch time op builder detection validation
delock Feb 26, 2024
4382ebe
change accelerator detection logic
delock Feb 26, 2024
cc21d46
fallback to gloo when oneccl_binding_for_pytorch is not installed
delock Feb 26, 2024
b746322
add a workflow to test opbuilder-update
delock Feb 26, 2024
a9aab4d
remove triton from dependenc
delock Feb 26, 2024
53fc44a
remove opbuild-update and change nv-inference.yml only
delock Feb 28, 2024
e5533ba
make installed_ops check accelerator name consistency
delock Feb 28, 2024
ef01f0d
Merge branch 'master' into gma/launch_opbuilder_detection
tjruwase Feb 28, 2024
22cc43c
fix accelerator override name
delock Feb 29, 2024
1691ef3
fix formatting check
delock Feb 29, 2024
cf2ea66
regenerate compatible ops every time
delock Feb 29, 2024
c040105
remove ipex and oneccl_pt_binding installation in cpu-inference workflow
delock Feb 29, 2024
8259122
fix cpu-inference and nv-inference workflow
delock Feb 29, 2024
d13fe5c
add missing quotation mark
delock Feb 29, 2024
ccaeb72
import ALL_OPS in git_version_info.py
delock Feb 29, 2024
2b6707f
build oneCCL with parallel make -j
delock Feb 29, 2024
a3bc2f8
adding missing package dependency
delock Mar 1, 2024
04bd061
fix cpu-inference workflow
delock Mar 1, 2024
9116302
Merge branch 'master' into gma/launch_opbuilder_detection
tjruwase Mar 1, 2024
1a1a71b
fix cpu-inference workflow and pre-compile workflow
delock Mar 2, 2024
87367e1
remove py-cpuinfo and psutil preinstall
delock Mar 2, 2024
52fc101
Merge branch 'master' into gma/launch_opbuilder_detection
tjruwase Mar 2, 2024
8baa89e
Merge branch 'master' into gma/launch_opbuilder_detection
loadams Mar 4, 2024
0a63463
Merge branch 'master' into gma/launch_opbuilder_detection
loadams Mar 5, 2024
4bba5e1
Skip test when its fp16
delock Mar 5, 2024
7cd08cd
fix elastic test
delock Mar 5, 2024
b28d81b
Better dequantization skipping
delock Mar 5, 2024
2ea44ba
fix format
delock Mar 5, 2024
57bab57
add numactl into dependency
delock Mar 6, 2024
f06f6b6
Merge branch 'master' into gma/launch_opbuilder_detection
loadams Mar 6, 2024
47888ec
Use bf16 data type for test if accelerator does not support fp16
delock Mar 7, 2024
a317fe8
skip more tests requires bf16
delock Mar 7, 2024
a03cc56
skip more UTs
delock Mar 7, 2024
cd916f7
skip more tests that CPU accelerator does not support
delock Mar 7, 2024
c943ec2
change skip reason
delock Mar 7, 2024
30d3e69
skip a time out test
delock Mar 7, 2024
2658d41
fix test_zero
delock Mar 7, 2024
cd8672d
Get around lazy init issue in test_ds_config_dict.py
delock Mar 7, 2024
a1666ba
Merge branch 'master' into gma/launch_opbuilder_detection
loadams Mar 7, 2024
ff5380f
fix more ut failures
delock Mar 8, 2024
da808a2
fix more UT failure
delock Mar 8, 2024
45e146a
fix more UTs
delock Mar 8, 2024
2fd32b6
fix more tests
delock Mar 8, 2024
2e60462
better construct for preferred dtype
delock Mar 8, 2024
c754492
fix import error
delock Mar 8, 2024
e024e6f
remove scale for bf16 config
delock Mar 8, 2024
f55186d
pass more UTs
delock Mar 8, 2024
02cb9e3
fix more tests
delock Mar 8, 2024
ae544e1
Merge branch 'master' into gma/launch_opbuilder_detection
loadams Mar 8, 2024
4623622
change preferred_dtype into a function
delock Mar 8, 2024
43505ab
install pdsh in cpu-torch-latest.yml
delock Mar 9, 2024
79c4d6c
Merge branch 'master' into gma/launch_opbuilder_detection
tjruwase Mar 9, 2024
2e59d92
better test_lr_scheduler skipping
delock Mar 11, 2024
ad351e4
skip multinmode test
delock Mar 11, 2024
b2673df
preferred_dtype ==> preferred_dtype()
delock Mar 11, 2024
41ced03
fix more tests
delock Mar 11, 2024
ad19171
skip some special case
delock Mar 11, 2024
f4fe02b
fix error in nv-torch-latest
delock Mar 11, 2024
88567b3
fix error in test_zero_context
delock Mar 12, 2024
c94003b
remove "fp16" argument in checkpoint_correctness_verification
delock Mar 12, 2024
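The PR title and commits such as "change accelerator detection logic", "make installed_ops check accelerator name consistency", and "regenerate compatible ops every time" describe moving op builder compatibility detection to launch time, so the result tracks whichever accelerator is actually active rather than whatever was detected when the package was built. A minimal sketch of that idea, assuming DeepSpeed's get_accelerator() interface and an ALL_OPS registry of op builders with an is_compatible() check; the module path and helper name here are illustrative and not taken from this diff:

```python
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.all_ops import ALL_OPS  # assumed registry location


def detect_compatible_ops():
    """Recompute op compatibility for the accelerator that is active right now."""
    accelerator_name = get_accelerator().device_name()
    compatible_ops = {}
    for op_name, builder in ALL_OPS.items():
        # Each builder decides whether it can run on the current accelerator;
        # re-checking at launch lets the result follow an accelerator change.
        compatible_ops[op_name] = builder.is_compatible()
    return accelerator_name, compatible_ops
```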
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
@@ -365,7 +365,8 @@ def _set_dtype(self, ds_config, dtype):
else:
self.dtype = torch.float
else:
self.dtype = dtype or torch.half
self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported(
) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32

def patch_init_and_builtins(self):

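A note on the added fallback above: Python's conditional expression binds more loosely than or, and is_bf16_supported is referenced without being called, so the one-liner is easy to misread. The presumed intent (honor an explicit dtype, otherwise prefer fp16, then bf16, then fp32), written out as a self-contained sketch rather than the line as merged:

```python
import torch
from deepspeed.accelerator import get_accelerator


def resolve_default_dtype(dtype=None):
    # Honor an explicitly requested dtype; otherwise prefer fp16, then bf16,
    # and fall back to fp32 when neither half-precision type is supported.
    if dtype is not None:
        return dtype
    if get_accelerator().is_fp16_supported():
        return torch.float16
    if get_accelerator().is_bf16_supported():
        return torch.bfloat16
    return torch.float32
```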
22 changes: 4 additions & 18 deletions tests/unit/checkpoint/test_zero_optimizer.py
@@ -23,8 +23,6 @@ class TestZeROCheckpoint(DistributedTest):

@pytest.mark.parametrize('zero_stage', [3])
def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
config_dict = {
"train_batch_size": 2,
"optimizer": {
@@ -53,8 +51,6 @@ def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):
def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")

config_dict = {
"train_batch_size": 2,
@@ -95,8 +91,6 @@ def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")

config_dict = {
"train_batch_size": 2,
@@ -133,8 +127,6 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):

@pytest.mark.parametrize('zero_stage', [1, 2])
def test_hybrid_optimizer_state(self, tmpdir, zero_stage):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
config_dict = {
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 2,
@@ -161,8 +153,8 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage):

@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
def test_load_module_only(self, tmpdir, zero_stage):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
if zero_stage == 0 and get_accelerator().device_name() == "cpu":
pytest.skip("CPU Accelerator does not support this test")
config_dict = {
"train_batch_size": 2,
"optimizer": {
@@ -336,8 +328,8 @@ def test_load_immediate_save(self, tmpdir, zero_stage):

@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
def test_load_immediate_save(self, tmpdir, zero_stage):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
if zero_stage == 0 and get_accelerator().device_name() == "cpu":
pytest.skip("CPU Accelerator does not support this test")
config_dict = {
"train_batch_size": 4,
"optimizer": {
@@ -421,8 +413,6 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
@pytest.mark.parametrize('zero_stage', [1, 2, 3])
def test_load_optimizer_state(self, tmpdir, zero_stage):

if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
@@ -454,8 +444,6 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage):
@pytest.mark.parametrize('zero_stage', [1, 2, 3])
def test_not_load_optimizer_state(self, tmpdir, zero_stage):

if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
@@ -485,8 +473,6 @@ def test_load_module_only(self, tmpdir, zero_stage):

@pytest.mark.parametrize('zero_stage', [1, 2, 3])
def test_load_module_only(self, tmpdir, zero_stage):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
config_dict = {
"train_batch_size": 2,
"optimizer": {
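The fp16-only skips removed above pair with commits such as "Use bf16 data type for test if accelerator does not support fp16" and "better construct for preferred dtype": instead of skipping, the tests pick a data type the active accelerator can actually run. A hedged sketch of that pattern; preferred_dtype() is the helper name used in the commit messages, and the config keys are standard DeepSpeed config sections, but the exact test wiring may differ:

```python
import torch
from deepspeed.accelerator import get_accelerator


def preferred_dtype():
    # fp16 when the accelerator supports it, else bf16, else plain fp32.
    if get_accelerator().is_fp16_supported():
        return torch.float16
    if get_accelerator().is_bf16_supported():
        return torch.bfloat16
    return torch.float32


def make_test_config(zero_stage=1):
    config = {
        "train_batch_size": 2,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": zero_stage},
    }
    # Enable the mixed-precision section that matches the preferred dtype; the
    # bf16 section takes no loss-scale arguments (see "remove scale for bf16 config").
    if preferred_dtype() == torch.float16:
        config["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif preferred_dtype() == torch.bfloat16:
        config["bf16"] = {"enabled": True}
    return config
```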