AssertionError: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3 #6793

Open
66RomanReigns opened this issue Nov 26, 2024 · 13 comments

@66RomanReigns

I encountered an issue while using DeepSpeed with ZeRO Stage 3 optimization. I received the following error: no_sync is not compatible with ZeRO Stage 3. I’m not sure how to resolve this conflict.

If anyone has experience with this or knows how to resolve it, could you please guide me? Thank you in advance!

[rank0]: File "/root/miniconda3/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1997, in no_sync
[rank0]: assert not self.zero_optimization_partition_gradients(),
[rank0]: AssertionError: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
0%| | 0/168 [00:00<?, ?it/s]
W1126 23:28:07.821000 402381 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 402434 closing signal SIGTERM
E1126 23:28:11.641000 402381 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 402435) of binary: /root/miniconda3/envs/llama/bin/python

@WeiluXu

WeiluXu commented Nov 26, 2024

My guess is wrong, please see thehir0's reply

@thehir0

thehir0 commented Nov 26, 2024

My code snippet:

    def _broadcast_to_vllm(self, model: DeepSpeedEngine):
        # avoid OOM
        torch.cuda.empty_cache()
        model = model.module
        count, num_params = 0, len(list(model.named_parameters()))
        for name, param in model.named_parameters():
            count += 1  # empty_cache at last param

            # Fire all vllm engines for broadcast
            if torch.distributed.get_rank() == 0:
                shape = param.shape if self.accelerator.deepspeed_plugin.zero_stage != 3 else param.ds_shape
                refs = [
                    engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
                    for engine in self.vllm_engines
                ]

            # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
            with deepspeed.zero.GatheredParameters([param], enabled=self.accelerator.deepspeed_plugin.zero_stage == 3):
                if torch.distributed.get_rank() == 0:
                    torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
                    ray.get(refs)

With deepspeed version 0.16.0, I get the same error at: deepspeed.zero.GatheredParameters([param], enabled=self.accelerator.deepspeed_plugin.zero_stage == 3)

With deepspeed version 0.15.4, I get a different assertion failure:

_broadcast_to_vllm
    with deepspeed.zero.GatheredParameters([param], enabled=self.accelerator.deepspeed_plugin.zero_stage == 3):
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2241, in __exit__
    self.params[0].partition(param_list=self.params, has_been_updated=False)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1386, in partition
    self._partition(param_list, has_been_updated=has_been_updated)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1535, in _partition
    self._partition_param(param, has_been_updated=has_been_updated)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1568, in _partition_param
    free_param(param)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 284, in free_param
    assert not param.ds_active_sub_modules, param.ds_summary()
AssertionError: {'id': 0, 'status': 'AVAILABLE', 'numel': 544997376, 'ds_numel': 544997376, 'shape': (152064, 3584), 'ds_shape': (152064, 3584), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([34062336])}

Everything works if grad_accum = 1; these errors occur only when grad_accum > 1.

@hynnn

hynnn commented Nov 27, 2024

Same problem.
My training configuration hasn't changed; it worked yesterday, but today it doesn't 🚬
Has it been resolved?

@LaoWangGB

Using deepspeed==0.15.4 solves the problem.

@yejoon-lee

I faced the same error with deepspeed==0.16.0, but it seems to be fine with deepspeed==0.15.4
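A minimal sketch of a version guard that a training script could run at startup, based only on the two versions reported in this thread (0.16.0 fails, 0.15.4 works). The helper names `parse_release`, `has_no_sync_assertion`, and `installed_deepspeed_version` are hypothetical, and treating every release from 0.16.0 onward as affected is an assumption that later releases may invalidate.

```python
import itertools
from importlib import metadata


def parse_release(version):
    """Turn a version string such as '0.16.0+cu121' into a tuple like (0, 16, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits of each component (drops '+cu121', 'rc1', ...).
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_no_sync_assertion(version):
    """Assumption: the assertion is present from 0.16.0 onward (per this thread)."""
    return parse_release(version) >= (0, 16, 0)


def installed_deepspeed_version():
    """Return the installed deepspeed version string, or None if it is not installed."""
    try:
        return metadata.version("deepspeed")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_deepspeed_version()
    if found is not None and has_no_sync_assertion(found):
        print(f"deepspeed {found}: no_sync may assert under ZeRO gradient partitioning")
```

Pinning with `pip install deepspeed==0.15.4` remains the workaround reported above; the guard only makes the mismatch visible earlier.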

@yuanyangeli

Using deepspeed==0.15.4 solves the problem.

It works.

@Luxanna-Real

I faced the same error with deepspeed==0.16.0, but it seems to be fine with deepspeed==0.15.4

Thank you, this is very helpful.

@inkcherry
Contributor

Same issue in ZeRO-3 training; it is likely related to #6675.

@samar-khanna
Copy link

samar-khanna commented Dec 2, 2024

@66RomanReigns I think this issue should be re-opened; downgrading the version is not a long-term fix. It's also a problem for ZeRO Stage 2.

@allblueee

Same problem, but with ZeRO stage 2. Solved by using deepspeed==0.15.4. Thx~

@dalan2014

Fixed this issue by setting gradient_accumulation_steps=1 while using deepspeed==0.16.0.
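Setting gradient_accumulation_steps=1 on its own shrinks the effective batch size. A rough sketch of how the config could be adjusted to keep the global batch unchanged, using DeepSpeed's documented relation train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The function name `flatten_accumulation` is hypothetical, and the larger micro-batch needs more GPU memory, so this may not fit on every setup.

```python
def flatten_accumulation(micro_batch_per_gpu, grad_accum_steps, world_size):
    """Return DeepSpeed batch settings with accumulation folded into the micro-batch."""
    # Global batch the original configuration was training with.
    global_batch = micro_batch_per_gpu * grad_accum_steps * world_size
    return {
        "train_batch_size": global_batch,
        # Each GPU now processes in one step what it used to accumulate.
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu * grad_accum_steps,
        "gradient_accumulation_steps": 1,
    }


# Example: 4 GPUs, micro-batch 2, 8 accumulation steps.
batch_settings = flatten_accumulation(micro_batch_per_gpu=2, grad_accum_steps=8, world_size=4)
print(batch_settings)
```

The returned keys can be merged into an existing DeepSpeed config dict alongside the zero_optimization section.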

@tjruwase
Contributor

tjruwase commented Dec 4, 2024

@66RomanReigns, @allblueee, @inkcherry the reason for this assertion is that the no_sync context manager is meant to disable gradient reduction during the backward pass. However, this behavior conflicts with the gradient partitioning of ZeRO-2 and ZeRO-3, which requires gradient reduction. That is why we added the assertion to properly support the no_sync context manager.

Can you explain why you need the no_sync context manager in your code?
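For calling code that wraps micro-batches in no_sync during gradient accumulation, one possible way to avoid the assertion is to enter no_sync only when the engine does not partition gradients. The sketch below assumes the engine exposes the two methods visible in the traceback at the top of this thread (`no_sync` and `zero_optimization_partition_gradients`). `accumulation_context` and `StubEngine` are hypothetical names used for illustration; the stub stands in for a real DeepSpeed engine so the snippet runs without DeepSpeed installed.

```python
import contextlib


def accumulation_context(engine, at_boundary):
    """Pick the context manager for one micro-batch backward pass."""
    if at_boundary or engine.zero_optimization_partition_gradients():
        # Let gradient reduction run: needed at accumulation boundaries, and on
        # every micro-batch when gradients are partitioned (ZeRO-2 / ZeRO-3).
        return contextlib.nullcontext()
    # Gradients are not partitioned, so reduction can be skipped for this step.
    return engine.no_sync()


class StubEngine:
    """Hypothetical stand-in for a DeepSpeed engine, for illustration only."""

    def __init__(self, partitions_gradients):
        self._partitions_gradients = partitions_gradients
        self.no_sync_calls = 0

    def zero_optimization_partition_gradients(self):
        return self._partitions_gradients

    @contextlib.contextmanager
    def no_sync(self):
        # Mirrors the assertion quoted in the traceback.
        assert not self.zero_optimization_partition_gradients(), (
            "no_sync context manager is incompatible with gradient partitioning"
        )
        self.no_sync_calls += 1
        yield


zero3_engine = StubEngine(partitions_gradients=True)
with accumulation_context(zero3_engine, at_boundary=False):
    pass  # engine.backward(loss) would go here

zero1_engine = StubEngine(partitions_gradients=False)
with accumulation_context(zero1_engine, at_boundary=False):
    pass  # reduction skipped for this micro-batch
```

With the stub, the ZeRO-3 case never enters no_sync, so the assertion is not reached, while the non-partitioned case still skips reduction between boundaries.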

@tjruwase
Contributor

tjruwase commented Dec 4, 2024

@thehir0, can you please open a separate ticket for your issue?
