[BUG] Enabling drop_tokens in MoE layer causes inference to hang #6809

Open
Shamauk opened this issue Nov 29, 2024 · 1 comment
Labels: bug, inference

Comments

@Shamauk

Shamauk commented Nov 29, 2024

Describe the bug
When I set drop_tokens=False for deepspeed.moe.layer.MoE, the program hangs. If I rely on eval_capacity_factor instead it works, but I would rather not overshoot the capacity factor just to prevent token dropping, since I have noticed that DeepSpeed allocates memory according to the capacity factor that is set.

To Reproduce

# https://proceedings.mlr.press/v162/rajbhandari22a/rajbhandari22a.pdf
# https://github.com/microsoft/DeepSpeed

import torch
import torch.nn as nn
import sys
import os 
import time
import csv
from tqdm import tqdm
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import argparse 
import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int) 
parser.add_argument("--world_size", default=8, type=int)
args = parser.parse_args()

def setup():
    os.environ["HF_HOME"] = "/cache"
    os.environ["HF_DATASETS_CACHE"] = "/cache"
    os.environ["TRITON_HOME"] = "/.triton"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    deepspeed.init_distributed()

    torch.cuda.set_device(args.local_rank)

def run_inference_workload():
    setup()

    model_name = f"google/switch-base-128"
    model = AutoModel.from_pretrained(model_name, cache_dir="/cache")

    # Wrapper so the DeepSpeed MoE output tuple (output, l_aux, exp_counts)
    # matches the (hidden_states, (router_logits, expert_index)) structure that
    # SwitchTransformersLayerFF expects from SwitchTransformersSparseMLP.
    class MLPWrapper(nn.Module):
        def __init__(self, child):
            super().__init__()
            self.child = child

        def forward(self, x):
            x = self.child(x)
            return x[0], (x[1], x[2])

    # Replace each SwitchTransformersSparseMLP with a DeepSpeed MoE layer
    def add_deepspeed_moe_model(module):
        if type(module).__name__ == "SwitchTransformersLayerFF":
            for child_name, child in module.named_children():
                if type(child).__name__ == "SwitchTransformersSparseMLP":
                    router = getattr(child, "router")
                    experts = getattr(child, "experts")
                    if type(experts) == nn.ModuleDict:
                        experts = list(experts.values())
                    
                    num_experts_per_gpu = 128 // args.world_size

                    experts = experts[args.local_rank*num_experts_per_gpu:(args.local_rank+1)*num_experts_per_gpu]

                    new = deepspeed.moe.layer.MoE(
                        hidden_size=768,
                        expert=experts[0],
                        num_experts=128,
                        ep_size=args.world_size,
                        k=1,
                        eval_capacity_factor=10.0,
                        # drop_tokens=False,
                        use_tutel=True,
                        top2_2nd_expert_sampling=False,
                        use_rts=False,
                    )

                    with torch.no_grad():
                        new.deepspeed_moe.gate.wg.weight.copy_(router.classifier.weight)
                        for i in range(len(experts)):
                            new.deepspeed_moe.experts.deepspeed_experts[i].wi.weight.copy_(experts[i].wi.weight)
                            new.deepspeed_moe.experts.deepspeed_experts[i].wo.weight.copy_(experts[i].wo.weight)

                    setattr(module, child_name, MLPWrapper(new))
        else:
            for child in module.children():
                add_deepspeed_moe_model(child)

    add_deepspeed_moe_model(model)

    model.eval()
    model.cuda()

    ds_engine = deepspeed.init_inference(
        model,
        dtype=torch.float,
        replace_with_kernel_inject=False,
        moe={
            "enabled": True,
            "ep_size": args.world_size, 
            "moe_experts": [128],
        }
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/cache")

    dataset = load_dataset(
        "bookcorpus/bookcorpus", 
        split=f"train[:3200]", 
        streaming=False, 
        trust_remote_code=True, 
        cache_dir="/cache"
    )
    sampler = DistributedSampler(
        dataset, 
        num_replicas=args.world_size, 
        rank=args.local_rank, 
        shuffle=True, 
        seed=49
    )
    def collate_fn(batch):
        texts = ["summarize: " + item["text"] for item in batch]
        
        tokenized = tokenizer(
            texts, 
            padding=True, 
            truncation=True, 
            max_length=120,
            return_tensors="pt"
        )
        return {
            **tokenized,
            "decoder_input_ids": torch.tensor([[tokenizer.pad_token_id]]*len(batch))
        }
    loader = DataLoader(
        dataset, 
        sampler=sampler, 
        batch_size=100,
        collate_fn=collate_fn,
    )

    run_standard_experiment(ds_engine, loader)


def run_standard_experiment(ds_engine, loader):    
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = {k: v.cuda() for k, v in batch.items()}
            ds_engine(
                input_ids=batch["input_ids"], 
                attention_mask=batch["attention_mask"],
                decoder_input_ids=batch["decoder_input_ids"],
            )

if __name__ == "__main__":
    run_inference_workload()

    print("All done :)")

Steps to reproduce the behavior:

  1. deepspeed --num_gpus 8 script.py
  2. It runs without problems, but if you uncomment the drop_tokens=False line, the program hangs

Expected behavior
I expect the program to run to completion even with drop_tokens=False.

ds_report output

[2024-11-29 16:34:46,436] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
 [WARNING]  gds requires the dev libaio .so object and headers but these were not found.
 [WARNING]  gds: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.4.0a0+f70bd71a48.nv24.06
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.15.3, unknown, unknown
torch cuda version ............... 12.5
torch hip version ................ None
nvcc version ..................... 12.5
deepspeed wheel compiled w. ...... torch 2.4, cuda 12.5
shared memory (/dev/shm) size .... 251.87 GB

System info (please complete the following information):

  • Ubuntu 22.04
  • 1 Node DGX1 8xV100
  • Transformers 4.46.0
  • DeepSpeed 0.15.3
  • Python 3.10.12

Docker context
nvcr.io/nvidia/pytorch:24.06-py3

@Shamauk added the bug and inference labels on Nov 29, 2024
@traincheck-team

traincheck-team commented Dec 6, 2024

The pipeline appears to be stuck in all_to_all communication; the problem is caused by inconsistent tensor shapes, which in turn result from inconsistent expert capacity across workers during inference.

We used our tool to investigate potential anomalies in tensor shapes during the _AllToAll.apply operations for both dispatched_input and expert_output. Debugging statements inserted before and after the _AllToAll.apply calls revealed the issue.

Below is the debugging output (with num_gpus=2):

[DEBUG] Before ALL-to-ALL: dispatched_input shape torch.Size([4, 3, 32])
[DEBUG] Before ALL-to-ALL: dispatched_input shape torch.Size([4, 4, 32])
[DEBUG] After ALL-to-ALL: dispatched_input shape torch.Size([4, 3, 32]) # one process stalls here
[DEBUG] Before ALL-to-ALL: expert_output shape torch.Size([4, 3, 32]) # both processes stall afterwards

The dispatched_input tensor shows inconsistent shapes across processes during the _AllToAll operation, suggesting that drop_tokens=False leads to mismatched tensor sizes. This inconsistency causes the program to silently stall.
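
For illustration, the same kind of check can be reproduced with a small helper that gathers each rank's tensor shape before the collective and fails fast instead of hanging. The helper below is hypothetical (assert_same_shape is not part of DeepSpeed or of our tool) and is only meant to show the idea:

    # Hypothetical debugging helper: verify a tensor has the same shape on every
    # rank before entering a collective such as all_to_all, so a mismatch raises
    # a clear error immediately instead of deadlocking later.
    import torch
    import torch.distributed as dist

    def assert_same_shape(t: torch.Tensor, name: str, group=None) -> None:
        world_size = dist.get_world_size(group)
        shapes = [None] * world_size
        # all_gather_object exchanges small Python objects (here, the shape tuple).
        dist.all_gather_object(shapes, tuple(t.shape), group=group)
        if len(set(shapes)) != 1:
            raise RuntimeError(
                f"[rank {dist.get_rank(group)}] {name} shape mismatch across ranks: {shapes}"
            )

    # Example placement: right before the all_to_all on dispatched_input, e.g.
    #   assert_same_shape(dispatched_input, "dispatched_input", group=ep_group)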

Detailed Findings

Further investigation into the topXGating function revealed that when drop_tokens=False, the capacity is computed dynamically as:

    new_capacity = torch.max(exp_counts).to(logits.device)

  • Here, exp_counts is the number of tokens assigned to each expert on the local rank.
  • Since exp_counts depends on the routing of each rank's own tokens (i.e. on the local logits distribution), it naturally differs between ranks, so the resulting new_capacity is inconsistent across processes.

This rank-dependent capacity calculation leads to shape mismatches in the tensors sent into the _AllToAll collective, causing the program to hang. The debug output above shows exactly this: one rank computed a capacity of 3 and the other a capacity of 4, which appears as the mismatched second dimension of dispatched_input (torch.Size([4, 3, 32]) vs. torch.Size([4, 4, 32]), where the second dimension is the per-expert capacity).


Suggested Fix

To ensure consistent behavior across nodes when drop_tokens=False:

  1. Synchronize new_capacity across all nodes: replace the purely local capacity computation with a collective reduction so that every process uses the same new_capacity value (a fuller sketch follows this list). Note that torch.distributed.all_reduce operates in place and returns None:

     new_capacity = torch.max(exp_counts).to(logits.device)
     torch.distributed.all_reduce(new_capacity, op=torch.distributed.ReduceOp.MAX)

  2. Explicit synchronization in the _AllToAll calls: add a step that synchronizes (or at least verifies) tensor shapes before dispatching them to _AllToAll.

  3. Alternative approach: rework the logic in topXGating so that new_capacity is computed deterministically, independent of rank-specific variation, for instance by falling back to a fixed capacity based on eval_capacity_factor when drop_tokens=False.
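
As a concrete illustration of suggestion 1, below is a minimal, self-contained sketch of how the capacity could be synchronized. It is not DeepSpeed's actual topXGating code; the helper name synchronized_capacity is ours, and it assumes exp_counts lives on the device used by the collective backend and that the appropriate process group is passed in as group.

    # Illustrative sketch, not DeepSpeed's gating implementation: reduce the
    # locally computed capacity with MAX so every rank pads its dispatched
    # tensors to the same size before the all_to_all.
    import torch
    import torch.distributed as dist

    def synchronized_capacity(exp_counts: torch.Tensor, group=None) -> torch.Tensor:
        # exp_counts: this rank's per-expert token counts, shape [num_experts].
        new_capacity = torch.max(exp_counts)
        if dist.is_available() and dist.is_initialized():
            # all_reduce is in place and returns None; afterwards every rank
            # holds the global maximum, so all ranks pad to the same capacity.
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
        return new_capacity

In the gating code this would replace the local new_capacity = torch.max(exp_counts).to(logits.device) computation; in DeepSpeed the reduction would presumably need to run over the expert-parallel process group so that every rank participating in the same all_to_all agrees on the capacity.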

We recommend implementing these changes to resolve the issue and prevent shape mismatches during collective communication. Please let us know if further assistance is needed.
