Describe the bug
When I set drop_tokens=False for deepspeed.moe.layer.MoE the program hangs. If I use eval_capacity_factor it works, but I would rather not overshoot the capacity factor just to prevent token dropping, since I have noticed that DeepSpeed allocates memory according to the capacity factor that is set.
To Reproduce
Steps to reproduce the behavior:
deepspeed --num_gpus 8 script.py
It runs without problems, but if you uncomment the drop_tokens argument (see the sketch below), the program hangs.
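The original script.py is not attached to the issue; the sketch below only illustrates the kind of MoE configuration being described (the expert module, hidden size, and expert count are assumptions, not the reporter's actual values):

```python
import torch
from deepspeed.moe.layer import MoE

class ExpertMLP(torch.nn.Module):  # hypothetical expert module, for illustration only
    def __init__(self, hidden_size=512):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.fc(x)

moe = MoE(
    hidden_size=512,
    expert=ExpertMLP(512),
    num_experts=8,
    ep_size=8,                 # e.g. one expert per GPU on the 8xV100 node
    k=1,
    capacity_factor=1.0,
    eval_capacity_factor=1.0,  # raising this works around the hang but costs memory
    # drop_tokens=False,       # uncommenting this line makes the program hang
)
```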
Expected behavior
I expect that the program runs.
ds_report output
[2024-11-29 16:34:46,436] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] NVIDIA Inference is only supported on Ampere and newer architectures
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
[WARNING] gds requires the dev libaio .so object and headers but these were not found.
[WARNING] gds: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
[WARNING] using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.4.0a0+f70bd71a48.nv24.06
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.15.3, unknown, unknown
torch cuda version ............... 12.5
torch hip version ................ None
nvcc version ..................... 12.5
deepspeed wheel compiled w. ...... torch 2.4, cuda 12.5
shared memory (/dev/shm) size .... 251.87 GB
System info (please complete the following information):
Ubuntu 22.04
1 Node DGX1 8xV100
Transformers 4.46.0
DeepSpeed 0.15.3
Python 3.10.12
Docker context
nvcr.io/nvidia/pytorch:24.06-py3
The pipeline seems to be stuck at the all_to_all communication. The problem is caused by inconsistent tensor shapes, which in turn result from each worker computing a different capacity during inference.
We used our tool to investigate potential anomalies in tensor shapes during the _AllToAll.apply operations for both dispatched_input and expert_output. Debugging statements inserted before and after the _AllToAll.apply calls revealed the issue.
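The tool and the exact insertion points are not shown in the comment; the instrumentation presumably looked something like the following (illustrative only, wrapped around the _AllToAll.apply calls in MOELayer.forward):

```python
import torch.distributed as dist

def debug_shape(tag, tensor):
    # Print the tensor shape from every rank so cross-rank mismatches show up.
    print(f"[DEBUG] {tag} shape {tuple(tensor.shape)} (rank {dist.get_rank()})",
          flush=True)

# debug_shape("Before ALL-to-ALL: dispatched_input", dispatched_input)
# dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
# debug_shape("After ALL-to-ALL: dispatched_input", dispatched_input)
```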
Below is the debugging output (with num_gpus=2):
[DEBUG] Before ALL-to-ALL: dispatched_input shape torch.Size([4, 3, 32])
[DEBUG] Before ALL-to-ALL: dispatched_input shape torch.Size([4, 4, 32])
[DEBUG] After ALL-to-ALL: dispatched_input shape torch.Size([4, 3, 32]) # one process stalls here
[DEBUG] Before ALL-to-ALL: expert_output shape torch.Size([4, 3, 32]) # both processes stall afterwards
The dispatched_input tensor shows inconsistent shapes across processes during the _AllToAll operation, suggesting that drop_tokens=False leads to mismatched tensor sizes. This inconsistency causes the program to silently stall.
Detailed Findings
Further investigation into the topXGating function revealed that:
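The snippet the comment refers to is not quoted; the behavior it describes amounts to something like the sketch below, which is not verbatim DeepSpeed source:

```python
import torch

def local_capacity(exp_counts: torch.Tensor) -> torch.Tensor:
    # Sketch of the drop_tokens=False behavior described above: the capacity is
    # taken from the busiest *local* expert, so it depends on this rank's own
    # routing decisions rather than on a fixed capacity_factor.
    new_capacity = torch.max(exp_counts)
    return new_capacity
```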
Here, exp_counts represents the number of tokens assigned to each expert.
Since exp_counts depends on the local logits distribution, which can vary across workers because each rank routes different tokens (and because of slight differences in computation or rounding), the resulting new_capacity is inconsistent across processes.
This dynamic and non-deterministic capacity calculation leads to shape mismatches in tensors sent to the _AllToAll collective communication, causing the program to hang.
Suggested Fix
To ensure consistent behavior across nodes when drop_tokens=False:
Synchronize new_capacity across all nodes:
Replace the purely local capacity computation with a collective operation (for example, an all-reduce that takes the maximum) so that every process uses the same capacity value; see the sketch below.
Add a step to synchronize tensor shapes before dispatching them to _AllToAll to ensure consistency.
Alternative Approach:
Consider reworking the logic in topXGating to guarantee a deterministic new_capacity calculation that is independent of node-specific variations. For instance, use a fixed capacity scaling based on eval_capacity_factor as a fallback when drop_tokens=False.
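A minimal sketch of the first suggestion, assuming the MoE all-to-all process group is available where the capacity is computed (this is not the actual DeepSpeed patch):

```python
import torch
import torch.distributed as dist

def synchronized_capacity(exp_counts: torch.Tensor, group=None) -> torch.Tensor:
    """Return a capacity that is identical on every rank.

    exp_counts: per-expert token counts on the local rank.
    group: the process group used by the MoE all-to-all (assumed available here).
    """
    new_capacity = torch.max(exp_counts).detach().clone()
    if dist.is_initialized():
        # Take the maximum over all ranks so no rank has to drop or truncate
        # tokens and every rank pads dispatched_input to the same size.
        # (With the NCCL backend the tensor must first be moved to the GPU.)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
    return new_capacity
```

The second suggestion could then reuse this synchronized value to pad dispatched_input to the same shape on every rank before the collective call.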
We recommend implementing these changes to resolve the issue and prevent shape mismatches during collective communication. Please let us know if further assistance is needed.