Add fp16 support for Qwen1.5-MoE models (A2.7B) to DeepSpeed-FastGen (#5403)

This PR adds fp16 support for Qwen1.5-MoE-A2.7B models.

It implements the model support requested in microsoft/DeepSpeed-MII#457.

### Test Code

For the MII pipeline:
```python
import mii

pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False)
if pipe.is_rank_0:
    print(responses[0])
```
For Hugging Face:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
model = AutoModelForCausalLM.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B",
                                             device_map="auto",
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True).eval()
print(model)
inputs = tokenizer('DeepSpeed is', return_tensors='pt').to(model.device)
pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0)
test = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
print(test)
```

### Qwen1.5-MoE-A2.7B
Hugging Face output with prompt "DeepSpeed is":
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```
DeepSpeed-FastGen output with prompt "DeepSpeed is":
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```

DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding:
```
 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
```
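The 8-way sharded run reuses the same pipeline script: with MII's non-persistent pipeline, the tensor-parallel degree follows the launcher's world size, so (assuming 8 local GPUs and a script name of `pipeline_test.py`, both hypothetical here) the launch would look like `deepspeed --num_gpus 8 pipeline_test.py`. All three outputs are identical, confirming fp16 parity between the Hugging Face baseline and DeepSpeed-FastGen, with and without sharding.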

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Abhishek Kulkarni <[email protected]>
Showing 9 changed files with 508 additions and 1 deletion.
2 changes: 2 additions & 0 deletions blogs/deepspeed-fastgen/README.md
```diff
@@ -233,6 +233,8 @@ We currently support the following model architectures in this alpha release of
 * [Phi-2](https://huggingface.co/models?other=phi-msft)
 * [Phi-3](https://huggingface.co/models?other=phi3)
 * [Qwen](https://huggingface.co/models?other=qwen)
+* [Qwen2](https://huggingface.co/models?other=qwen2)
+* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)
 
 All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
```
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
```diff
@@ -23,6 +23,7 @@
     Phi3Policy,
     QwenPolicy,
     Qwen2Policy,
+    Qwen2MoePolicy,
 )
 from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
 from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -126,6 +127,8 @@ def build_hf_engine(path: str,
         policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
     elif model_config.model_type == "qwen2":
         policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
+    elif model_config.model_type == "qwen2_moe":
+        policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
     else:
         raise ValueError(f"Unsupported model type {model_config.model_type}")
 
```
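The engine selects a policy by the `model_type` string in the checkpoint's Hugging Face config. As a quick sanity check (a sketch using the standard transformers API; the Hub model ID is an assumption), you can confirm what a Qwen1.5-MoE checkpoint reports:

```python
from transformers import AutoConfig

# "qwen2_moe" is the value that routes to the new Qwen2MoePolicy branch above
config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
print(config.model_type)  # qwen2_moe
```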
3 changes: 3 additions & 0 deletions
```diff
@@ -11,5 +11,8 @@
     } else if (2 == N_TOP_K) { \
         constexpr int CONST_TOP_K = 2; \
         __VA_ARGS__(); \
+    } else if (4 == N_TOP_K) { \
+        constexpr int CONST_TOP_K = 4; \
+        __VA_ARGS__(); \
     } \
 }()
```
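This hunk extends the compile-time top-k dispatch in the gating kernel: Qwen1.5-MoE's router picks the top 4 of its 60 experts per token, so the macro gains an `N_TOP_K == 4` branch. A minimal PyTorch sketch of that routing step (illustrative only, not the CUDA kernel; normalization details omitted):

```python
import torch

def topk_routing(logits: torch.Tensor, top_k: int = 4):
    """logits: [num_tokens, num_experts] scores from the gate projection."""
    probs = torch.softmax(logits, dim=-1)                   # per-token expert distribution
    weights, expert_ids = torch.topk(probs, top_k, dim=-1)  # keep the 4 highest-probability experts
    return weights, expert_ids

weights, expert_ids = topk_routing(torch.randn(8, 60))      # 8 tokens, 60 experts
print(expert_ids.shape)  # torch.Size([8, 4])
```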
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
```diff
@@ -18,3 +18,4 @@
 from .phi3 import *
 from .qwen import *
 from .qwen_v2 import *
+from .qwen_v2_moe import *
```
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py
```diff
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .policy import Qwen2MoePolicy
```
103 changes: 103 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen1.5-MoE-A2.7B model looks like this:
Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
              (act_fn): SiLU()
            )
          )
          (shared_expert): Qwen2MoeMLP(
            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
        )
        (input_layernorm): Qwen2MoeRMSNorm()
        (post_attention_layernorm): Qwen2MoeRMSNorm()
      )
    )
    (norm): Qwen2MoeRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)
'''


class Qwen2MoeTransformerContainer(LayerContainer):
    """
    Transformer layer container for the Qwen2Moe model.
    """
    qkv_w: UnfusedQKVParameter
    qkv_b: UnfusedQKVParameter
    attn_out_w: AttentionOutputParameter
    moe_gate: MoEGatingWeightParameter
    moe_mlp_1: UnfusedMoEGatedMLPParameter
    moe_mlp_2: UnfusedMoEMLP2Parameter
    shared_moe_mlp_1: GatedMLPParameter
    shared_moe_mlp_2: MLP2Parameter
    shared_moe_gate: MoEGatingWeightParameter
    attn_norm_gamma: NormParameter
    mlp_norm_gamma: NormParameter

    PARAM_MAPPING = {
        "self_attn.q_proj.weight": "qkv_w.q_params",
        "self_attn.k_proj.weight": "qkv_w.k_params",
        "self_attn.v_proj.weight": "qkv_w.v_params",
        "self_attn.q_proj.bias": "qkv_b.q_params",
        "self_attn.k_proj.bias": "qkv_b.k_params",
        "self_attn.v_proj.bias": "qkv_b.v_params",
        "self_attn.o_proj.weight": "attn_out_w.params",
        "mlp.gate.weight": "moe_gate.params",
        "mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts",
        "mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts",
        "mlp.experts.*.down_proj.weight": "moe_mlp_2.experts",
        "mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params",
        "mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params",
        "mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params",
        "mlp.shared_expert_gate.weight": "shared_moe_gate.params",
        "input_layernorm.weight": "attn_norm_gamma.params",
        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
    }


class Qwen2MoeNonTransformerContainer(LayerContainer):
    """
    Non-Transformer layer container for the Qwen2Moe model.
    """
    word_emb: EmbeddingParameter
    word_unembed: UnembedParameter
    final_norm: NormParameter

    PARAM_MAPPING = {
        "model.embed_tokens.weight": "word_emb.params",
        "model.norm.weight": "final_norm.params",
        "lm_head.weight": "word_unembed.params",
    }
```
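In `PARAM_MAPPING`, a `*` entry fans one rule out over all 60 expert indices, so `mlp.experts.*.gate_proj.weight` collects `mlp.experts.0.gate_proj.weight` through `mlp.experts.59.gate_proj.weight` into `moe_mlp_1.gating_experts`. A rough sketch of that wildcard matching (illustrative only; the actual resolution is implemented by `LayerContainer`):

```python
import re

def expert_index(pattern: str, name: str):
    """Return the expert index if `name` matches `pattern`, else None."""
    # "*" stands for a numeric expert index in the checkpoint parameter name
    regex = "^" + re.escape(pattern).replace(r"\*", r"(\d+)") + "$"
    m = re.match(regex, name)
    return int(m.group(1)) if m else None

print(expert_index("mlp.experts.*.gate_proj.weight", "mlp.experts.59.gate_proj.weight"))     # 59
print(expert_index("mlp.experts.*.gate_proj.weight", "mlp.shared_expert.gate_proj.weight"))  # None
```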