Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix fp16 Qwen2 series model to DeepSpeed-FastGen #6028

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
for checkpoint in self._all_ckpt_paths:
inference_logger().info(f"Loading checkpoint: {checkpoint}")
checkpoint_sd = self._checkpoint_load_fn(checkpoint)

# If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embeddings weights
if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings:
if self.model_config.model_type == "qwen2":
checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]

param_keys = list(checkpoint_sd.keys())
for param_name in param_keys:
param = checkpoint_sd[param_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@
} else if (4 == N_TOP_K) { \
constexpr int CONST_TOP_K = 4; \
__VA_ARGS__(); \
} else if (8 == N_TOP_K) { \
constexpr int CONST_TOP_K = 8; \
__VA_ARGS__(); \
} \
}()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 96, 128]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]
supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
theta_base: float) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
LAUNCH_KV_ROTARY_FOR_Q_RATIO(2)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(4)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(5)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(6)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(7)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(8)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(16)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(29)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,45 @@
from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen1.5-MoE-A2.7B model looks like this:
# HF Qwen2-57B-A14B model looks like this:

Qwen2MoeForCausalLM(
(model): Qwen2MoeModel(
(embed_tokens): Embedding(151936, 2048)
(embed_tokens): Embedding(151936, 3584)
(layers): ModuleList(
(0-23): 24 x Qwen2MoeDecoderLayer(
(0-27): 28 x Qwen2MoeDecoderLayer(
(self_attn): Qwen2MoeSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=2048, bias=True)
(v_proj): Linear(in_features=2048, out_features=2048, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2MoeRotaryEmbedding()
)
(mlp): Qwen2MoeSparseMoeBlock(
(gate): Linear(in_features=2048, out_features=60, bias=False)
(gate): Linear(in_features=3584, out_features=64, bias=False)
(experts): ModuleList(
(0-59): 60 x Qwen2MoeMLP(
(gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
(up_proj): Linear(in_features=2048, out_features=1408, bias=False)
(down_proj): Linear(in_features=1408, out_features=2048, bias=False)
(0-63): 64 x Qwen2MoeMLP(
(gate_proj): Linear(in_features=3584, out_features=2560, bias=False)
(up_proj): Linear(in_features=3584, out_features=2560, bias=False)
(down_proj): Linear(in_features=2560, out_features=3584, bias=False)
(act_fn): SiLU()
)
)
(shared_expert): Qwen2MoeMLP(
(gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
(up_proj): Linear(in_features=2048, out_features=5632, bias=False)
(down_proj): Linear(in_features=5632, out_features=2048, bias=False)
(gate_proj): Linear(in_features=3584, out_features=20480, bias=False)
(up_proj): Linear(in_features=3584, out_features=20480, bias=False)
(down_proj): Linear(in_features=20480, out_features=3584, bias=False)
(act_fn): SiLU()
)
(shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
(shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False)
)
(input_layernorm): Qwen2MoeRMSNorm()
(post_attention_layernorm): Qwen2MoeRMSNorm()
(input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
(post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
)
)
(norm): Qwen2MoeRMSNorm()
(norm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
)
(lm_head): Linear(in_features=2048, out_features=151936, bias=False)
(lm_head): Linear(in_features=3584, out_features=151936, bias=False)
)
'''

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def n_heads(self) -> int:

@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size
return self._config.shared_expert_intermediate_size

@property
def n_heads_kv(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool:
if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
return False

if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8:
return False

return True
Expand Down