diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index 46a84c61f884..874a66a1e2fe 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
         for checkpoint in self._all_ckpt_paths:
             inference_logger().info(f"Loading checkpoint: {checkpoint}")
             checkpoint_sd = self._checkpoint_load_fn(checkpoint)
+
+            # If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embedding weights
+            if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings:
+                if self.model_config.model_type == "qwen2":
+                    checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]
+
             param_keys = list(checkpoint_sd.keys())
             for param_name in param_keys:
                 param = checkpoint_sd[param_name]
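
For context: `tie_word_embeddings` in the HF config means the input embedding and the output projection share one weight matrix, so tied Qwen2 checkpoints ship no `lm_head.weight` tensor at all. The hunk above aliases it from `model.embed_tokens.weight` before the parameter mapping runs. A minimal standalone sketch of the same fallback (the `state_dict` and shapes below are illustrative, not the engine's API):

```python
import torch

# Illustrative checkpoint for a tied-embedding model: the exported file
# carries the embedding matrix but no lm_head tensor. A real Qwen2 shape
# would be (vocab_size, hidden_size), e.g. (151936, 3584).
state_dict = {"model.embed_tokens.weight": torch.randn(16, 8)}
tie_word_embeddings = True  # mirrors model_config.tie_word_embeddings

if tie_word_embeddings and "lm_head.weight" not in state_dict:
    # Alias rather than copy: both keys reference the same storage.
    state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]

assert (state_dict["lm_head.weight"].data_ptr() ==
        state_dict["model.embed_tokens.weight"].data_ptr())
```
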
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
index 2cc430ccfe34..f5104f899d9c 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
+++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -14,5 +14,8 @@
     } else if (4 == N_TOP_K) {         \
         constexpr int CONST_TOP_K = 4; \
         __VA_ARGS__();                 \
+    } else if (8 == N_TOP_K) {         \
+        constexpr int CONST_TOP_K = 8; \
+        __VA_ARGS__();                 \
     }                                  \
 }()
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
index 7e1ec1a13cb9..aacbec0bd3ae 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
+++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
@@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):

     supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
     supported_head_sizes = [64, 80, 96, 128]
-    supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]
+    supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71]

     def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
                  theta_base: float) -> None:
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
index fbafece5ccf2..f7bc693eefee 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
+++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
@@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(2)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(4)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(5)
+    LAUNCH_KV_ROTARY_FOR_Q_RATIO(6)
+    LAUNCH_KV_ROTARY_FOR_Q_RATIO(7)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(8)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(16)
     LAUNCH_KV_ROTARY_FOR_Q_RATIO(29)
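
The `q_ratio` here is the GQA group size, `n_q_heads // n_kv_heads`: how many query heads share one KV head, and therefore which template instantiation the rotary kernel needs. Ratio 7 is exactly what the Qwen2-57B-A14B projection shapes later in this diff imply; the arithmetic below rederives it (head size 128 is taken from the kernel's supported sizes, and 6 presumably covers a related configuration):

```python
# Derive the GQA ratio from the Qwen2-57B-A14B projections in this diff.
head_size = 128
n_q_heads = 3584 // head_size   # q_proj out_features -> 28 query heads
n_kv_heads = 512 // head_size   # k_proj out_features -> 4 KV heads

q_ratio = n_q_heads // n_kv_heads
assert q_ratio == 7             # hence 7 joins supported_q_ratios
```

Because each supported ratio becomes a `constexpr` kernel instantiation, the Python-side `supported_q_ratios` list and the `LAUNCH_KV_ROTARY_FOR_Q_RATIO` cases in the `.cu` file must be extended in lockstep, as this diff does.
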
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
index b4621257ff82..e499379da7e3 100644
--- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
@@ -8,45 +8,45 @@
 from ..common_parameters import *
 from ..layer_container_base import LayerContainer
 '''
- # HF Qwen1.5-MoE-A2.7B model looks like this:
+ # HF Qwen2-57B-A14B model looks like this:
 Qwen2MoeForCausalLM(
   (model): Qwen2MoeModel(
-    (embed_tokens): Embedding(151936, 2048)
+    (embed_tokens): Embedding(151936, 3584)
     (layers): ModuleList(
-      (0-23): 24 x Qwen2MoeDecoderLayer(
+      (0-27): 28 x Qwen2MoeDecoderLayer(
         (self_attn): Qwen2MoeSdpaAttention(
-          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
-          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
+          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
+          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
+          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
           (rotary_emb): Qwen2MoeRotaryEmbedding()
         )
         (mlp): Qwen2MoeSparseMoeBlock(
-          (gate): Linear(in_features=2048, out_features=60, bias=False)
+          (gate): Linear(in_features=3584, out_features=64, bias=False)
           (experts): ModuleList(
-            (0-59): 60 x Qwen2MoeMLP(
-              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
-              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
-              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
+            (0-63): 64 x Qwen2MoeMLP(
+              (gate_proj): Linear(in_features=3584, out_features=2560, bias=False)
+              (up_proj): Linear(in_features=3584, out_features=2560, bias=False)
+              (down_proj): Linear(in_features=2560, out_features=3584, bias=False)
               (act_fn): SiLU()
             )
           )
           (shared_expert): Qwen2MoeMLP(
-            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
-            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
-            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
+            (gate_proj): Linear(in_features=3584, out_features=20480, bias=False)
+            (up_proj): Linear(in_features=3584, out_features=20480, bias=False)
+            (down_proj): Linear(in_features=20480, out_features=3584, bias=False)
             (act_fn): SiLU()
           )
-          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
+          (shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False)
         )
-        (input_layernorm): Qwen2MoeRMSNorm()
-        (post_attention_layernorm): Qwen2MoeRMSNorm()
+        (input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
+        (post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
       )
     )
-    (norm): Qwen2MoeRMSNorm()
+    (norm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
   )
-  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+  (lm_head): Linear(in_features=3584, out_features=151936, bias=False)
 )
 '''
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
index 7cddbf978369..c7841b24e5fc 100644
--- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
@@ -73,7 +73,7 @@ def n_heads(self) -> int:

     @property
     def intermediate_dim(self) -> int:
-        return self._config.intermediate_size
+        return self._config.shared_expert_intermediate_size

     @property
     def n_heads_kv(self) -> int:
diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
index bd90cbd5d697..a9b01d1233cd 100644
--- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
+++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
@@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool:
         if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
             return False

-        if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
+        if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8:
             return False

         return True
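
Qwen2-57B-A14B routes each token to 8 of its 64 experts (`num_experts_per_tok = 8` in the HF config, matching the 64-way gate in the module tree above), so `top_k == 8` must pass this gate, and the `TOP_K_SWITCH` macro earlier in the diff must have a matching `CONST_TOP_K = 8` branch. A standalone sketch of the updated check, using a stand-in config class (the real `DSMoEConfig` carries more fields):

```python
import torch
from dataclasses import dataclass


@dataclass
class MoEConfig:
    """Stand-in for DSMoEConfig with only the fields this check reads."""
    input_dtype: torch.dtype
    top_k: int


def supports_config(config: MoEConfig) -> bool:
    # fp16/bf16 activations only, as in the original check.
    if config.input_dtype not in (torch.float16, torch.bfloat16):
        return False
    # top-k values that have a CONST_TOP_K branch in TOP_K_SWITCH.
    if config.top_k not in (1, 2, 4, 8):
        return False
    return True


assert supports_config(MoEConfig(torch.bfloat16, top_k=8))      # Qwen2-MoE
assert not supports_config(MoEConfig(torch.bfloat16, top_k=3))  # unsupported
```

Writing the membership test with `not in` over a tuple is equivalent to the chained `!=` comparisons in the patched file and keeps the supported set a one-token change to extend.
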