diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py
index 9281640f844a..c320108f55e5 100644
--- a/deepspeed/inference/v2/engine_factory.py
+++ b/deepspeed/inference/v2/engine_factory.py
@@ -21,6 +21,7 @@
     FalconPolicy,
     PhiPolicy,
     QwenPolicy,
+    Qwen2Policy,
 )
 from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
 from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -120,6 +121,8 @@ def build_hf_engine(path: str,
         policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
     elif model_config.model_type == "qwen":
         policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
+    elif model_config.model_type == "qwen2":
+        policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
     else:
         raise ValueError(f"Unsupported model type {model_config.model_type}")
 
diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py
index 869c4316cdc7..14b0654a8c36 100644
--- a/deepspeed/inference/v2/model_implementations/__init__.py
+++ b/deepspeed/inference/v2/model_implementations/__init__.py
@@ -16,3 +16,4 @@
 from .falcon import *
 from .phi import *
 from .qwen import *
+from .qwen_v2 import *
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py b/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py
new file mode 100644
index 000000000000..80b09757c74d
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .policy import Qwen2Policy
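With the factory change above, checkpoints whose config.json reports model_type == "qwen2" are routed to the new Qwen2Policy. The sketch below shows how an engine might be built for such a checkpoint; it is illustrative only, not part of this diff, and assumes build_hf_engine accepts a checkpoint path plus a RaggedInferenceEngineConfig (the config type the policy file imports) and that the config is default-constructible. The checkpoint path is a placeholder.

# Illustrative usage only -- not part of this diff. Assumes build_hf_engine(path, engine_config)
# and a default-constructible RaggedInferenceEngineConfig; the checkpoint path is a placeholder.
from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

engine = build_hf_engine("/path/to/qwen2-checkpoint", RaggedInferenceEngineConfig())
# A checkpoint whose config.json has model_type == "qwen2" now resolves to Qwen2Policy
# instead of raising "Unsupported model type qwen2".
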
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2/container.py
new file mode 100644
index 000000000000..6556d87d6afb
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2/container.py
@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# Create a container object to save model-specific tensors using the policy file above.
+
+from ..common_parameters import *
+from ..layer_container_base import LayerContainer
+'''
+ # HF Qwen2 model looks like this:
+
+Qwen2ForCausalLM(
+  (model): Qwen2Model(
+    (embed_tokens): Embedding(151936, 1024)
+    (layers): ModuleList(
+      (0-23): 24 x Qwen2DecoderLayer(
+        (self_attn): Qwen2SdpaAttention(
+          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
+          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
+          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
+          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
+          (rotary_emb): Qwen2RotaryEmbedding()
+        )
+        (mlp): Qwen2MLP(
+          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
+          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
+          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
+          (act_fn): SiLU()
+        )
+        (input_layernorm): Qwen2RMSNorm()
+        (post_attention_layernorm): Qwen2RMSNorm()
+      )
+    )
+    (norm): Qwen2RMSNorm()
+  )
+  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
+)
+'''
+
+
+class Qwen2TransformerContainer(LayerContainer):
+    """
+    Transformer layer container for the Qwen2 model.
+    """
+    qkv_w: UnfusedQKVParameter
+    qkv_b: UnfusedQKVParameter
+    attn_out_w: AttentionOutputParameter
+    mlp_1_w: GatedMLPParameter
+    mlp_2_w: MLP2Parameter
+    attn_norm_gamma: NormParameter
+    mlp_norm_gamma: NormParameter
+
+    PARAM_MAPPING = {
+        "self_attn.q_proj.weight": "qkv_w.q_params",
+        "self_attn.k_proj.weight": "qkv_w.k_params",
+        "self_attn.v_proj.weight": "qkv_w.v_params",
+        "self_attn.q_proj.bias": "qkv_b.q_params",
+        "self_attn.k_proj.bias": "qkv_b.k_params",
+        "self_attn.v_proj.bias": "qkv_b.v_params",
+        "self_attn.o_proj.weight": "attn_out_w.params",
+        "mlp.gate_proj.weight": "mlp_1_w.gate_params",
+        "mlp.up_proj.weight": "mlp_1_w.up_params",
+        "mlp.down_proj.weight": "mlp_2_w.params",
+        "input_layernorm.weight": "attn_norm_gamma.params",
+        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
+    }
+
+
+class Qwen2NonTransformerContainer(LayerContainer):
+    """
+    Non-Transformer layer container for the Qwen2 model.
+    """
+    word_emb: EmbeddingParameter
+    word_unembed: UnembedParameter
+    final_norm: NormParameter
+
+    PARAM_MAPPING = {
+        "model.embed_tokens.weight": "word_emb.params",
+        "model.norm.weight": "final_norm.params",
+        "lm_head.weight": "word_unembed.params",
+    }
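The PARAM_MAPPING keys in the transformer container are relative to each decoder layer, while the non-transformer keys are full module paths. An optional sanity check against a real Qwen2-architecture checkpoint is sketched below; it is not part of this diff, uses the standard transformers API, and the model id is purely illustrative.

# Optional sanity check, not part of this diff: confirm the per-layer names assumed by
# PARAM_MAPPING exist in a real Qwen2-architecture checkpoint (model id is illustrative).
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B")
param_names = {name for name, _ in hf_model.named_parameters()}

per_layer_keys = [
    "self_attn.q_proj.weight", "self_attn.q_proj.bias",
    "self_attn.k_proj.weight", "self_attn.k_proj.bias",
    "self_attn.v_proj.weight", "self_attn.v_proj.bias",
    "self_attn.o_proj.weight",
    "mlp.gate_proj.weight", "mlp.up_proj.weight", "mlp.down_proj.weight",
    "input_layernorm.weight", "post_attention_layernorm.weight",
]
missing = [k for k in per_layer_keys if f"model.layers.0.{k}" not in param_names]
print("missing per-layer parameters:", missing or "none")
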
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2/model.py
new file mode 100644
index 000000000000..d535462a954d
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2/model.py
@@ -0,0 +1,221 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+
+import deepspeed.comm as dist
+
+from ...allocator import empty_from
+from ...inference_utils import ActivationType, DtypeEnum
+from .. import *
+from ...modules.configs import *
+from ...modules.interfaces import *
+from ...modules import heuristics
+from ...ragged import RaggedBatchWrapper
+
+from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer
+
+
+class Qwen2InferenceModel(DSTransformerModelBase):
+    """
+    Inference model implementation for ragged batching for Qwen2 models.
+    """
+
+    _non_transformer: Optional[Qwen2NonTransformerContainer]
+    """
+    Embed + unembed container. Specializing the type annotation.
+    """
+
+    _transformer: Optional[Iterable[Qwen2TransformerContainer]]
+    """
+    Per-layer transformer container. Specializing the type annotation.
+    """
+    """
+    Properties inherited from `DSInferenceModelBase`
+    """
+
+    @property
+    def max_sequence_length(self) -> int:
+        return self._config.max_seq_length
+
+    """
+    Properties inherited from `DSTransformerModelBase`
+    """
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.num_hidden_layers
+
+    @property
+    def model_dim(self) -> int:
+        return self._config.hidden_size
+
+    @property
+    def vocab_size(self) -> int:
+        return self._config.vocab_size
+
+    @property
+    def head_size(self) -> int:
+        return self.model_dim // self.n_heads
+
+    @property
+    def n_heads(self) -> int:
+        return self._config.num_attention_heads
+
+    @property
+    def intermediate_dim(self) -> int:
+        return self._config.intermediate_size
+
+    @property
+    def n_heads_kv(self) -> int:
+        return self._config.num_key_value_heads
+
+    @property
+    def activation_dtype(self) -> DtypeEnum:
+        # TODO(ZonePG): bf16 inference results may be different from huggingface bf16,
+        # because in rms_norm, Qwen still uses float() instead of bf16
+        # if self._config.torch_dtype == torch.float16:
+        #     return DtypeEnum.fp16
+        # elif self._config.torch_dtype == torch.bfloat16:
+        #     return DtypeEnum.bf16
+        # else:
+        #     raise NotImplementedError("Only fp16 and bf16 are supported")
+        return DtypeEnum.fp16
+
+    @property
+    def mlp_activation_fn(self) -> ActivationType:
+        return ActivationType.SiGLU
+
+    @property
+    def norm_type(self) -> NormTypeEnum:
+        return NormTypeEnum.RMSNorm
+
+    @property
+    def positional_embedding_type(self) -> PositionalEmbeddingType:
+        return PositionalEmbeddingType.rotate_half
+
+    @property
+    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
+        return RotateHalfConfig(theta_base=self._config.rope_theta)
+
+    def make_norm_layer(self) -> None:
+        """
+        Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
+
+        TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
+        but for now we'll just use the same one for all of them.
+        """
+        norm_config = DSNormConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            type=self.norm_type,
+            channels=self.model_dim,
+            residual_dtype=self.activation_dtype,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+            eps=self._config.rms_norm_eps,
+        )
+
+        self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
+
+    """
+    Forward implementations
+    """
+
+    def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
+        """
+        Performs the embedding lookup prior to running the transformer of the model.
+
+        Arguments:
+            ragged_batch (RaggedBatchWrapper): The batch to embed.
+
+        Returns:
+            torch.Tensor: The embedded batch.
+        """
+        embed = self.embed(ragged_batch, self._non_transformer.word_emb)
+
+        if embed.shape[-1] != self.model_dim:
+            raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")
+
+        return embed
+
+    def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
+                                   ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
+        optimization to fuse the layer norm of the next layer into the current layer.
+
+        Arguments:
+            layer_idx (int): The index of the layer to execute.
+            residual (torch.Tensor): The residual tensor from the previous layer.
+            hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
+                hidden states after pre normalization.
+            ragged_batch_info (RaggedBatchWrapper): The batch metadata.
+        """
+        # TODO(cmikeh2): Distribute ragged_batch_info to all modules
+
+        cur_params = self._transformer[layer_idx]
+        kv_cache = self.state_manager.get_cache(layer_idx)
+
+        hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
+        hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
+        hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
+
+        if self.tp_size > 1:
+            dist.all_reduce(hidden_states, group=self._base_mp_group)
+
+        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)
+
+        # Should be configurable in the future
+        hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
+        hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)
+
+        if self.tp_size > 1:
+            dist.all_reduce(hidden_states, group=self._base_mp_group)
+
+        if layer_idx != self.num_layers - 1:
+            next_params = self._transformer[layer_idx + 1]
+            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
+        else:
+            # On last layer, we just need to perform the residual add. Adding into the residual
+            # here is safe.
+            residual.add_(hidden_states)
+
+        return residual, hidden_states
+
+    def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
+        """
+        Performs unembedding of the hidden states to logits. This will only sample the final
+        token of each sequence.
+        """
+        logits = self.unembed(hidden_states,
+                              self._non_transformer.word_unembed,
+                              ragged_batch_info,
+                              gamma=self._non_transformer.final_norm)
+
+        if self.tp_size > 1:
+            comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
+            full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))
+
+            dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)
+
+            full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))
+
+            return full_logits
+        else:
+            return logits
+
+    def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
+
+        residual = self._forward_embed(wrapped_batch)
+
+        residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)
+
+        for layer_idx in range(self.num_layers):
+            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
+                                                                      wrapped_batch)
+
+        return self._forward_unembed(residual, wrapped_batch)
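For reference, ActivationType.SiGLU together with GatedMLPParameter/MLP2Parameter corresponds to the gated SiLU MLP shown in the module dump in container.py (gate_proj/up_proj/down_proj plus SiLU). The sketch below gives the intended math in eager PyTorch as reference semantics only; it is not the fused DeepSpeed kernel path used by mlp_1/mlp_2 above.

# Reference semantics only -- not part of this diff and not the fused DeepSpeed kernels.
import torch
import torch.nn.functional as F

def qwen2_mlp_reference(x: torch.Tensor, gate_w: torch.Tensor, up_w: torch.Tensor,
                        down_w: torch.Tensor) -> torch.Tensor:
    # x: [tokens, hidden]; weights stored HF-style as [out_features, in_features]
    return (F.silu(x @ gate_w.T) * (x @ up_w.T)) @ down_w.T
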
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py b/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py
new file mode 100644
index 000000000000..9c5db2ba0065
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Any
+
+from ...config_v2 import RaggedInferenceEngineConfig
+from ..inference_policy_base import ContainerMap, InferenceV2Policy
+from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer
+from .model import Qwen2InferenceModel
+
+
+class Qwen2Policy(InferenceV2Policy):
+
+    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2InferenceModel:
+        return Qwen2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
+
+    def build_container_map(self) -> ContainerMap:
+        map = ContainerMap()
+
+        transformer_containers = [Qwen2TransformerContainer(self.model) for _ in range(self.model.num_layers)]
+
+        map.set_transformer_params(['model.layers'], transformer_containers)
+
+        map.set_non_transformer_params(Qwen2NonTransformerContainer(self.model))
+
+        map.set_unmapped_params(
+            [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)])
+
+        return map
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 5cf655d8741a..142259c1b7df 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -1635,19 +1635,16 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
             secondary_end = secondary_start + secondary_partition_size
 
             one_dim_param = param.contiguous().view(-1)
 
-            start = partition_size * self.rank
-            end = start + partition_size
-            if start < param.ds_numel and end <= param.ds_numel:
-                if secondary_start < param.ds_numel and secondary_end <= param.ds_numel:
-                    sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size)
-                    param.ds_secondary_tensor.copy_(sec_src_tensor)
-            else:
-                if start < param.ds_numel:
-                    elements_to_copy = param.ds_numel - start
-                    elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups
-                    param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_(
-                        one_dim_param.narrow(0, secondary_start, elements_to_copy_sec))
+            # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size
+            sec_numel = param.ds_numel - secondary_start if secondary_end > param.ds_numel else secondary_partition_size
+
+            # copy from full tensor to secondary tensor
+            param.ds_secondary_tensor.narrow(0, 0,
+                                             sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
+
+            # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
+            get_accelerator().current_stream().synchronize()
 
             print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
                          force=False)
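The ZeRO change replaces the earlier branchy copy logic with a single clamped copy: sec_numel equals secondary_partition_size for every rank except the one holding the final, unpadded chunk, where it is truncated to the remaining real elements. A small worked example of that arithmetic follows; all numbers are made up for illustration and are not taken from the PR.

# Worked example with made-up sizes: how sec_numel clamps the copy for the last chunk.
ds_numel = 1000                   # unpadded element count of the parameter
secondary_partition_size = 256    # padded per-rank chunk size of the secondary tensor
secondary_start = secondary_partition_size * 3               # last secondary rank -> 768
secondary_end = secondary_start + secondary_partition_size   # 1024 > ds_numel
sec_numel = ds_numel - secondary_start if secondary_end > ds_numel else secondary_partition_size
assert sec_numel == 232   # only the real elements are copied; the padded tail is left untouched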