Fix fp16 Qwen2 series model support in DeepSpeed-FastGen (#6028)
Based on PR #5403 (Qwen1.5-MoE) and PR #5219 (Qwen1.5), this adds support for the Qwen2 series of models, including the 0.5B, 1.5B, 7B, 57B-A14B, and 72B variants.

Co-authored-by: Logan Adams <[email protected]>
ZonePG and loadams authored Aug 21, 2024
1 parent 7260890 commit e6fcc22
Showing 7 changed files with 34 additions and 23 deletions.
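
For context, once a supported Qwen2 checkpoint is available it can be served on DeepSpeed-FastGen through DeepSpeed-MII. A minimal sketch, assuming the `Qwen/Qwen2-7B-Instruct` model id and illustrative generation settings (none of which are part of this commit):

```python
# Hedged sketch: serve a Qwen2 model on DeepSpeed-FastGen via the MII pipeline API.
# The model id, prompt, and max_new_tokens below are illustrative assumptions.
import mii

pipe = mii.pipeline("Qwen/Qwen2-7B-Instruct")                 # loads the HF checkpoint onto the FastGen engine
responses = pipe(["DeepSpeed-FastGen is"], max_new_tokens=64)
print(responses[0].generated_text)
```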
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
         for checkpoint in self._all_ckpt_paths:
             inference_logger().info(f"Loading checkpoint: {checkpoint}")
             checkpoint_sd = self._checkpoint_load_fn(checkpoint)
+
+            # If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embedding weights
+            if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings:
+                if self.model_config.model_type == "qwen2":
+                    checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]
+
             param_keys = list(checkpoint_sd.keys())
             for param_name in param_keys:
                 param = checkpoint_sd[param_name]
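
The guard added above matters because the smaller Qwen2 variants are published with `tie_word_embeddings=True` (per their upstream configs) and ship no separate `lm_head.weight` tensor, so the loader must alias the embedding matrix. A minimal sketch of that aliasing, using a toy state dict with an illustrative shape:

```python
import torch

# Toy state dict mimicking a tied-embedding checkpoint: no "lm_head.weight" entry is stored.
checkpoint_sd = {"model.embed_tokens.weight": torch.randn(151936, 1024, dtype=torch.float16)}

tie_word_embeddings = True  # what model_config.tie_word_embeddings reports for the tied Qwen2 variants
if tie_word_embeddings and "lm_head.weight" not in checkpoint_sd:
    # Alias (not copy) the embedding matrix so both parameter names resolve to the same storage.
    checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]

assert checkpoint_sd["lm_head.weight"].data_ptr() == checkpoint_sd["model.embed_tokens.weight"].data_ptr()
```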
@@ -14,5 +14,8 @@
         } else if (4 == N_TOP_K) {            \
             constexpr int CONST_TOP_K = 4;    \
             __VA_ARGS__();                    \
+        } else if (8 == N_TOP_K) {            \
+            constexpr int CONST_TOP_K = 8;    \
+            __VA_ARGS__();                    \
         }                                     \
     }()
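
The new `N_TOP_K == 8` branch is needed because Qwen2-57B-A14B routes each token to 8 of its 64 experts. A rough Python sketch of the routing step this switch specializes; the token count and normalization details are illustrative, and the real kernel bakes `CONST_TOP_K` in at compile time:

```python
import torch

n_experts, top_k = 64, 8                       # Qwen2-57B-A14B-style routing
router_logits = torch.randn(4, n_experts)      # router logits for 4 tokens (illustrative)

# Softmax over all experts, then keep the top-k scores and expert ids per token.
# Whether the kept scores are renormalized afterwards is a per-model config detail.
probs = torch.softmax(router_logits, dim=-1)
weights, expert_ids = torch.topk(probs, k=top_k, dim=-1)
print(expert_ids.shape, weights.shape)         # torch.Size([4, 8]) torch.Size([4, 8])
```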
@@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):
 
     supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
     supported_head_sizes = [64, 80, 96, 128]
-    supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]
+    supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71]
 
     def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
                  theta_base: float) -> None:
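
Ratios 6 and 7 cover the grouped-query-attention layouts used by the Qwen2 checkpoints, where the ratio is simply the number of query heads served by each KV head. A small sketch of how that value falls out of a model config; the head counts below are illustrative placeholders, not figures taken from this PR:

```python
# q_ratio = query heads per KV head (the GQA group size handled by the rotary kernel).
# Head-count pairs are illustrative; real values come from the HF config fields
# num_attention_heads and num_key_value_heads.
examples = {
    "small-dense-model": (12, 2),   # -> q_ratio 6
    "mid-dense-model": (28, 4),     # -> q_ratio 7
}
for name, (n_q_heads, n_kv_heads) in examples.items():
    print(name, "q_ratio =", n_q_heads // n_kv_heads)
```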
@@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(2)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(4)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(5)
+    LAUNCH_KV_ROTARY_FOR_Q_RATIO(6)
+    LAUNCH_KV_ROTARY_FOR_Q_RATIO(7)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(8)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(16)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(29)
@@ -8,45 +8,45 @@
 from ..common_parameters import *
 from ..layer_container_base import LayerContainer
 '''
-# HF Qwen1.5-MoE-A2.7B model looks like this:
+# HF Qwen2-57B-A14B model looks like this:
 Qwen2MoeForCausalLM(
   (model): Qwen2MoeModel(
-    (embed_tokens): Embedding(151936, 2048)
+    (embed_tokens): Embedding(151936, 3584)
     (layers): ModuleList(
-      (0-23): 24 x Qwen2MoeDecoderLayer(
+      (0-27): 28 x Qwen2MoeDecoderLayer(
         (self_attn): Qwen2MoeSdpaAttention(
-          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
+          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
+          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
+          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
           (rotary_emb): Qwen2MoeRotaryEmbedding()
         )
         (mlp): Qwen2MoeSparseMoeBlock(
-          (gate): Linear(in_features=2048, out_features=60, bias=False)
+          (gate): Linear(in_features=3584, out_features=64, bias=False)
           (experts): ModuleList(
-            (0-59): 60 x Qwen2MoeMLP(
-              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
-              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
-              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
+            (0-63): 64 x Qwen2MoeMLP(
+              (gate_proj): Linear(in_features=3584, out_features=2560, bias=False)
+              (up_proj): Linear(in_features=3584, out_features=2560, bias=False)
+              (down_proj): Linear(in_features=2560, out_features=3584, bias=False)
               (act_fn): SiLU()
             )
           )
           (shared_expert): Qwen2MoeMLP(
-            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
-            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
-            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
+            (gate_proj): Linear(in_features=3584, out_features=20480, bias=False)
+            (up_proj): Linear(in_features=3584, out_features=20480, bias=False)
+            (down_proj): Linear(in_features=20480, out_features=3584, bias=False)
             (act_fn): SiLU()
           )
-          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
+          (shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False)
         )
-        (input_layernorm): Qwen2MoeRMSNorm()
-        (post_attention_layernorm): Qwen2MoeRMSNorm()
+        (input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
+        (post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
       )
     )
-    (norm): Qwen2MoeRMSNorm()
+    (norm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
   )
-  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+  (lm_head): Linear(in_features=3584, out_features=151936, bias=False)
 )
 '''
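
The docstring above is simply the printed module tree of the new reference checkpoint. A sketch of reproducing such a listing without materializing any weights; the model id is assumed, and meta-device construction needs a reasonably recent PyTorch and transformers:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton from its config on the meta device (no parameter download),
# then print it to obtain a tree like the one in the docstring.
config = AutoConfig.from_pretrained("Qwen/Qwen2-57B-A14B")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)
print(model)
```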

@@ -73,7 +73,7 @@ def n_heads(self) -> int:
 
     @property
     def intermediate_dim(self) -> int:
-        return self._config.intermediate_size
+        return self._config.shared_expert_intermediate_size
 
     @property
     def n_heads_kv(self) -> int:
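
For reference, the Qwen2-MoE configuration carries two distinct FFN widths: `moe_intermediate_size` for each routed expert and `shared_expert_intermediate_size` for the shared expert, which is what the property now returns. A quick sketch of reading both values (model id assumed):

```python
from transformers import AutoConfig

# Qwen2-MoE exposes separate widths for the routed experts and the shared expert.
cfg = AutoConfig.from_pretrained("Qwen/Qwen2-57B-A14B")
print("routed expert FFN width:", cfg.moe_intermediate_size)
print("shared expert FFN width:", cfg.shared_expert_intermediate_size)
```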
@@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool:
         if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
             return False
 
-        if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
+        if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8:
             return False
 
         return True
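
As a side note, the chain of inequality checks above could be expressed as a membership test, which stays readable as more top-k values gain kernel support. A hedged sketch of that alternative (the constant and function names are hypothetical, and this is not what the commit itself does):

```python
# Hypothetical constant listing the top-k values with specialized kernels.
SUPPORTED_TOP_K = (1, 2, 4, 8)

def top_k_supported(top_k: int) -> bool:
    return top_k in SUPPORTED_TOP_K
```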
